diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 97cf467cca07..e69de29bb2d1 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,158 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# Github code owners file -# This file is used as a convenient tool to map -# committers' areas of expertise and faciliate the review process. -# -# This may not be the non-comprehensive list and is meant to be -# updated over time. - -# Per ASF policy, committer have global write permission. -# We normally recommend committers to shepherd code in their area of expertise. -* @apache/tvm-committers - -# Order is important; the last matching pattern takes the most precedence. -# The sub modules should be ordered first by depth. -# Making sure we append new sub-module rules after exisiting modules rules. - -############################## -# Top-level Fallbacks -############################## -include/** @tqchen @jroesch @yzhliu @icemelon @junrushao1994 @comaniac @zhiics -src/** @tqchen @jroesch @yzhliu @icemelon @junrushao1994 @comaniac @zhiics -apps/** @tqchen @jroesch @yzhliu @icemelon @junrushao1994 @comaniac @zhiics -python/** @tqchen @jroesch @yzhliu @icemelon @junrushao1994 @comaniac @zhiics - -# Thirdparty license audit -3rdparty/** @tqchen @jroesch -licenses/** @tqchen @jroesch - -# JVM language -jvm/** @yzhliu - -# Golang -golang/** @srkreddy1238 - -# WASM -web/** @tqchen @jroesch - -# Docker -docker/** @areusch @leandron @jroesch - -# Conda -conda/** @tqchen @junrushao1994 @comaniac - -# CMake -cmake/** @jroesch @tqchen @areusch @junrushao1994 @comaniac - -# rust bindings -rust/** @jroesch @nhynes @nhynes - -# vta -vta/** @tmoreau89 @vegaluisjose - -# docs -docs/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon -tutorials/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon - -# tests -tests/** @comaniac @junrushao1994 @tqchen @jroesch @areusch @yzhliu @merrymercy @icemelon - -############################## -# Specific modules -############################## - -# automation related -src/auto_scheduler/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 @Hzfengsy -include/tvm/auto_scheduler/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 @Hzfengsy -python/tvm/auto_scheduler/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 @Hzfengsy - -python/tvm/autotvm/** @merrymercy @jcf94 @comaniac @junrushao1994 @vinx13 - -# node system and reflection -src/node/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac -include/tvm/node/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac - -# ir: Common IR -src/ir/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac -include/tvm/ir/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac -python/tvm/ir/** @junrushao1994 @vinx13 @tqchen @jroesch @comaniac - -# tir -src/tir/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were @Hzfengsy -include/tvm/tir/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were @Hzfengsy -python/tvm/tir/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were @Hzfengsy - -# te -src/te/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were -include/tvm/te/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were -python/tvm/te/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi @were - -# target -src/target/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi -include/tvm/target/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi -python/tvm/target/** @junrushao1994 @vinx13 @tqchen @kparzysz-quic @ZihengJiang @masahi - -# arith: Arithmetic module and simplifiers -src/arith/** @tqchen @junrushao1994 @vinx13 -include/tvm/arith/** @tqchen @junrushao1994 @vinx13 -python/tvm/arith/** @tqchen @junrushao1994 @vinx13 - -# parser -src/parser/** @jroesch @slyubomirsky - -# runtime -src/runtime/** @vinx13 @tqchen @FronzenGene @liangfu @areusch @tmoreau89 @ajtulloch @masahi @kazum @ZihengJiang @junrushao1994 -include/tvm/runtime/** @vinx13 @tqchen @FronzenGene @liangfu @areusch @tmoreau89 @ajtulloch @masahi @kazum @ZihengJiang @junrushao1994 -python/tvm/runtime/** @vinx13 @tqchen @FronzenGene @liangfu @areusch @tmoreau89 @ajtulloch @masahi @kazum @ZihengJiang @junrushao1994 - -# runtime/micro -src/runtime/micro/** @areusch @liangfu @tmoreau89 @manupa-arm -src/runtime/crt/** @areusch @liangfu @tmoreau89 @manupa-arm -include/tvm/runtime/crt/** @areusch @liangfu @tmoreau89 @manupa-arm -include/tvm/runtime/micro/** @areusch @liangfu @tmoreau89 @manupa-arm -python/tvm/micro/** @areusch @liangfu @tmoreau89 @manupa-arm - -# relay -src/relay/** @jroesch @slyubomirsky @icemelon @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 -include/tvm/relay/** @jroesch @slyubomirsky @icemelon @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 -python/tvm/relay/** @jroesch @slyubomirsky @icemelon @MarisaKirisame @ZihengJiang @yzhliu @vinx13 @mbrookhart @jwfromm @zhiics @anijain2305 @wweic @eqy @junrushao1994 - - -# relay/qnn -src/relay/qnn/** @jwfromm @anijain2305 @ZihengJiang -inlcude/tvm/relay/qnn/** @jwfromm @anijain2305 @ZihengJiang -python/tvm/relay/qnn/** @jwfromm @anijain2305 @ZihengJiang - -# relay/backend/contrib: BYOC -src/relay/backend/contrib/** @zhiics @trevor-m @comaniac @mbaret @manupa-arm - -# relay/frontends -python/tvm/relay/frontend/** @jwfromm @mbrookhart @srkreddy1238 @siju-samuel @Huyuwei @hlu1 @kazum @PariksheetPinjari909 - -# topi: Operator definitions -src/topi/** @Laurawly @Huyuwei @kevinthesun @jwfromm @vinx13 @masahi @FronzenGene @yzhliu @mbrookhart @ZihengJiang @jcf94 -include/tvm/topi/** @Laurawly @Huyuwei @kevinthesun @jwfromm @vinx13 @masahi @FronzenGene @yzhliu @mbrookhart @ZihengJiang @jcf94 -python/tvm/topi/** @Laurawly @Huyuwei @kevinthesun @jwfromm @vinx13 @masahi @FronzenGene @yzhliu @mbrookhart @ZihengJiang @jcf94 - - -# tvm/driver/ -python/tvm/driver/** @leandron @jwfromm @tqchen @jroesch - -# tvm/driver/tvmc -python/tvm/driver/tvmc/** @leandron @jwfromm diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index f55a0651a870..21683a5380bb 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -93,6 +93,10 @@ class IntSet : public ObjectRef { bool CanProveNonPositive() const; /*! \return Whether the set is proved to be larger than or equal to 0 */ bool CanProveNonNegative() const; + /*! \return Whether the set has upper bound. */ + bool HasUpperBound() const; + /*! \return Whether the set has lower bound. */ + bool HasLowerBound() const; /*! * \brief The single point value, call only if IsSinglePoint is true * \return The point value. @@ -164,6 +168,14 @@ Map ConvertDomMap(const std::unordered_map& * \return An integer set that can cover all the possible values of e. */ IntSet EvalSet(PrimExpr e, const Map& dom_map); +/*! + * \brief Same as EvalSet, but takes Map + * + * \param e The expression to be evaluated. + * \param dom_map The domain of each variable. + * \return An integer set that can cover all the possible values of e. + */ +IntSet EvalSet(PrimExpr e, const Map& dom_map); /*! * \brief Same as EvalSet, but takes unordered_map * @@ -172,6 +184,15 @@ IntSet EvalSet(PrimExpr e, const Map& dom_map); * \return An integer set that can cover all the possible values of e. */ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map); +/*! + * \brief Same as EvalSet, but takes Array + * + * \param exprs The expressions to be evaluated. + * \param dom_map The domain of each variable. + * \return An array of integer sets that can cover all the possible values. + */ +Array EvalSet(const Array& exprs, const Map& dom_map); + /*! * \brief Find an symbolic integer set that contains is union over * all the possible conditional values in dom_map. diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 22b4cd580e18..f2c6a54c93e7 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -350,6 +350,8 @@ Array> SubspaceDivide(const Array& bindings, bool require_bijective, arith::Analyzer* analyzer, DiagnosticContext diag_ctx); +PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr); + } // namespace arith } // namespace tvm #endif // TVM_ARITH_ITER_AFFINE_MAP_H_ diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index d7b23dd79c25..2b809459155e 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -32,7 +32,7 @@ class BuilderInputNode : public runtime::Object { IRModule mod; /*! \brief The target to be built for. */ Target target; - /*! \brief The optional parameters used for build */ + /*! \brief Parameters for Relay build module. */ Optional> params; void VisitAttrs(tvm::AttrVisitor* v) { @@ -55,7 +55,7 @@ class BuilderInput : public runtime::ObjectRef { * \brief Constructor of BuilderInput. * \param mod The IRModule to be built. * \param target The target to be built for. - * \param params The optional parameters used for build + * \param params Parameters for Relay build module. */ TVM_DLL explicit BuilderInput(IRModule mod, Target target, Optional> params = NullOpt); diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index f07d8e136644..307ec309c009 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -237,6 +237,7 @@ class PyDatabaseNode : public DatabaseNode { // PackedFuncs are all not visited, because the reflection system doesn't take care of them, // so it cannot be accessible on the python side. If there is such need from the future, // we can then add corresponding accessor methods to help access on python. + // // `f_has_workload` is not visited // `f_commit_workload` is not visited // `f_commit_tuning_record` is not visited diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h new file mode 100644 index 000000000000..bd4e4e9c37d3 --- /dev/null +++ b/include/tvm/meta_schedule/mutator.h @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_MUTATOR_H_ +#define TVM_META_SCHEDULE_MUTATOR_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! \brief Mutator is designed to mutate the trace to explore the design space. */ +class MutatorNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~MutatorNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Initialize the design space generator with tuning context. + * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. + */ + virtual void InitializeWithTuneContext(const TuneContext& context) = 0; + + /*! + * \brief Apply the mutator function to the given trace. + * \param trace The given trace for mutation. + * \param rand_state The random state for mutation. + * \return None if mutator failed, otherwise return the mutated trace. + */ + virtual Optional Apply(const tir::Trace& trace, + support::LinearCongruentialEngine::TRandState* rand_state) = 0; + + static constexpr const char* _type_key = "meta_schedule.Mutator"; + TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object); +}; + +/*! \brief The mutator with customized methods on the python-side. */ +class PyMutatorNode : public MutatorNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief Apply the mutator function to the given trace. + * \param trace The given trace for mutation. + * \return None if mutator failed, otherwise return the mutated trace. + */ + using FApply = runtime::TypedPackedFunc( + const tir::Trace&, support::LinearCongruentialEngine::TRandState rand_state)>; + /*! + * \brief Get the mutator as string with name. + * \return The string of the mutator. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` function. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyMutator's InitializeWithTuneContext method not implemented!"; + this->f_initialize_with_tune_context(context); + } + + Optional Apply(const tir::Trace& trace, + support::LinearCongruentialEngine::TRandState* rand_state) final { + ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!"; + return this->f_apply(trace, *rand_state); + } + + static constexpr const char* _type_key = "meta_schedule.PyMutator"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode); +}; + +/*! + * \brief Managed reference to MutatorNode + * \sa MutatorNode + */ +class Mutator : public runtime::ObjectRef { + public: + /*! \brief Create a Mutator that mutates the tile size. */ + TVM_DLL static Mutator MutateTileSize(); + /*! + * \brief Create a Mutator that mutates the parallel extent + * \param max_jobs_per_core The maximum number of parallel jobs per core. + * \return The created mutator. + */ + TVM_DLL static Mutator MutateParallel(int64_t max_jobs_per_core); + /*! \brief Create a Mutator that mutates auto unroll step */ + TVM_DLL static Mutator MutateUnroll(); + /*! + * \brief Create a Mutator that mutates the outcome of SampleComputeLocation + * \return The mutator created + */ + TVM_DLL static Mutator MutateComputeLocation(); + /*! + * \brief Create a mutator with customized methods on the python-side. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_apply The packed function of `Apply`. + * \param f_as_string The packed function of `AsString`. + * \return The mutator created. + */ + TVM_DLL static Mutator PyMutator( + PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyMutatorNode::FApply f_apply, // + PyMutatorNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_MUTATOR_H_ diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h new file mode 100644 index 000000000000..b1cb14cee725 --- /dev/null +++ b/include/tvm/meta_schedule/postproc.h @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_POSTPROC_H_ +#define TVM_META_SCHEDULE_POSTPROC_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! + * \brief Rules to apply a postprocessor to a schedule. + */ +class PostprocNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~PostprocNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Initialize the design space generator with tuning context. + * \param tune_context The tuning context for initialization. + * \note This method is supposed to be called only once before every other method. + */ + virtual void InitializeWithTuneContext(const TuneContext& context) = 0; + + /*! + * \brief Apply a postprocessor to the given schedule. + * \param sch The schedule to be post processed. + * \return Whether the postprocessor was successfully applied. + */ + virtual bool Apply(const tir::Schedule& sch) = 0; + + static constexpr const char* _type_key = "meta_schedule.Postproc"; + TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object); +}; + +/*! \brief The postprocessor with customized methods on the python-side. */ +class PyPostprocNode : public PostprocNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief Apply a postprocessor to the given schedule. + * \param sch The schedule to be post processed. + * \return Whether the postprocessor was successfully applied. + */ + using FApply = runtime::TypedPackedFunc; + /*! + * \brief Get the postprocessor function as string with name. + * \return The string of the postprocessor function. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `InitializeWithTuneContext` function. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` function. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` function. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyPostproc's InitializeWithTuneContext method not implemented!"; + this->f_initialize_with_tune_context(context); + } + + bool Apply(const tir::Schedule& sch) final { + ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!"; + return this->f_apply(sch); + } + + static constexpr const char* _type_key = "meta_schedule.PyPostproc"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode); +}; + +/*! + * \brief Managed reference to PostprocNode + * \sa PostprocNode + */ +class Postproc : public runtime::ObjectRef { + public: + /*! + * \brief Create a postprocessor with customized methods on the python-side. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_apply The packed function of `Apply`. + * \param f_as_string The packed function of `AsString`. + * \return The postprocessor created. + */ + TVM_DLL static Postproc PyPostproc( + PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyPostprocNode::FApply f_apply, // + PyPostprocNode::FAsString f_as_string); + /*! + * \brief Create a postprocessor that checks if all loops are static + * \return The postprocessor created + */ + TVM_DLL static Postproc DisallowDynamicLoop(); + /*! + * \brief Create a postprocessor that rewrites the cooperative fetch annotation to + * actual vectorized cooperative fetching in loop bindings. + * \return The postprocessor created. + */ + TVM_DLL static Postproc RewriteCooperativeFetch(); + /*! + * \brief Creates a postprocessor that applies parallelization, vectorization and auto unrolling + * according to the annotation of each block + * \return The postprocessor created + */ + TVM_DLL static Postproc RewriteParallelVectorizeUnroll(); + /*! + * \brief Create a postprocessor that rewrites reduction block by moving the init block out. + * \return The postprocessor created. + */ + TVM_DLL static Postproc RewriteReductionBlock(); + /*! + * \brief Create a postprocessor that adds thread binding to unbound blocks + * \return The postprocessor created. + */ + TVM_DLL static Postproc RewriteUnboundBlock(); + /*! + * \brief Create a postprocessor that tensorize Tensor Core related components + * \return The postprocessor created. + */ + TVM_DLL static Postproc RewriteTensorCore(); + + /*! + * \brief Creates a postprocessor that verifies if the GPU code is correct + * \return The postprocessor created + */ + TVM_DLL static Postproc VerifyGPUCode(); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_POSTPROC_H_ diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 8313da067f09..449c6cf7e4cf 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -141,7 +141,7 @@ class ScheduleRule : public runtime::ObjectRef { * - [blockIdx.x, vthread.x, threadIdx.x] on GPU * \param use_tensor_core Whether to apply tensor core wmma intrinsic for the computation * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit - * \param vector_load_max_len The length of vector lane in vectorized cooperative fetching. + * \param vector_load_lens The length of vector lane in vectorized cooperative fetching. * NullOpt means disable vectorization * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse. * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse. @@ -151,9 +151,26 @@ class ScheduleRule : public runtime::ObjectRef { Optional> tile_binds, // bool use_tensor_core, // Optional max_innermost_factor, // - Optional vector_load_max_len, // + Optional> vector_load_lens, // Optional> reuse_read, // Optional> reuse_write); + /*! + * \brief Create a rule: add-rfactor to some blocks if needed + * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the + * uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. Use -1 to disable + * parallelism. + * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit + * \return The schedule rule created + */ + TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core, // + Optional max_innermost_factor); + /*! + * \brief Create a schedule rule which applies cross-thread reduction to some reduction blocks + * correspondingly when needed + * \param thread_extents Candidates of thread axis extent (values are required to be positive). + * \return The schedule rule created + */ + TVM_DLL static ScheduleRule CrossThreadReduction(Array thread_extents); /*! * \brief A rule that randomly select a compute-at location for a free block * \return The rule created diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index e1c68c8a1a11..af128a4b60d3 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -28,6 +28,8 @@ namespace meta_schedule { // Forward declaration class TuneContext; +class CostModel; +class Database; /*! \brief The schedule (with input shapes) to be measured. */ class MeasureCandidateNode : public runtime::Object { @@ -135,7 +137,9 @@ class SearchStrategyNode : public runtime::Object { * \brief Update the search strategy with measurement results. * \param results The measurement results from the runner. */ - virtual void NotifyRunnerResults(const Array& results) = 0; + virtual void NotifyRunnerResults(const TuneContext& tune_context, + const Array& measure_candidates, + const Array& results) = 0; static constexpr const char* _type_key = "meta_schedule.SearchStrategy"; TVM_DECLARE_BASE_OBJECT_INFO(SearchStrategyNode, Object); @@ -165,7 +169,8 @@ class PySearchStrategyNode : public SearchStrategyNode { * \brief The function type of `NotifyRunnerResults` method. * \param results The measurement results from the runner. */ - using FNotifyRunnerResults = runtime::TypedPackedFunc&)>; + using FNotifyRunnerResults = runtime::TypedPackedFunc&, const Array&)>; /*! \brief The packed function to the `InitializeWithTuneContext` method. */ FInitializeWithTuneContext f_initialize_with_tune_context; @@ -208,10 +213,12 @@ class PySearchStrategyNode : public SearchStrategyNode { return this->f_generate_measure_candidates(); } - void NotifyRunnerResults(const Array& results) final { + void NotifyRunnerResults(const TuneContext& tune_context, + const Array& measure_candidates, + const Array& results) final { ICHECK(f_notify_runner_results != nullptr) << "PySearchStrategy's NotifyRunnerResults method not implemented!"; - this->f_notify_runner_results(results); + this->f_notify_runner_results(tune_context, measure_candidates, results); } static constexpr const char* _type_key = "meta_schedule.PySearchStrategy"; @@ -247,6 +254,35 @@ class SearchStrategy : public runtime::ObjectRef { */ TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total); + /*! + * \brief Constructor of replay func search strategy. + * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. + * \param num_trials_total The total number of trials for func replaying. + */ + TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total); + + /*! + * \brief Constructor of evolutionary search strategy. + * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. + * \param num_trials_total The total number of trials for evolutionary search. + * \param population_size The initial sample population. + * \param init_measured_ratio The ratio of measures samples in initial population. + * \param init_min_unmeasured The minimal size of unmeasured population in the initial sampling. + * \param genetic_num_iters The iterations to run the genetic algorithm. + * \param genetic_mutate_prob The probability of mutation. + * \param genetic_max_fail_count The maximum number to try evolving the given trace. + * \param eps_greedy The ratio to select samples in a greedy fashion via their predicted score. + */ + TVM_DLL static SearchStrategy EvolutionarySearch(int num_trials_per_iter, // + int num_trials_total, // + int population_size, // + double init_measured_ratio, // + int init_min_unmeasured, // + int genetic_num_iters, // + double genetic_mutate_prob, // + int genetic_max_fail_count, // + double eps_greedy); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode); }; diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index f28c33dc4fe4..0284a55e0d03 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -20,7 +20,9 @@ #define TVM_META_SCHEDULE_TASK_SCHEDULER_H_ #include +#include #include +#include #include #include @@ -78,7 +80,7 @@ class TaskSchedulerNode : public runtime::Object { /*! \brief The list of measure callbacks of the scheduler. */ Array measure_callbacks; - /*! \brief The default desctructor. */ + /*! \brief The default destructor. */ virtual ~TaskSchedulerNode() = default; void VisitAttrs(tvm::AttrVisitor* v) { @@ -248,15 +250,19 @@ class TaskScheduler : public runtime::ObjectRef { * \param runner The runner of the scheduler. * \param database The database of the scheduler. */ - TVM_DLL static TaskScheduler RoundRobin(Array tasks, // - Builder builder, // - Runner runner, // - Database database); // + TVM_DLL static TaskScheduler RoundRobin(Array tasks, // + Builder builder, // + Runner runner, // + Database database, // + Optional cost_model, // + Optional> measure_callbacks); TVM_DLL static TaskScheduler PyTaskScheduler( Array tasks, // Builder builder, // Runner runner, // Database database, // + Optional cost_model, // + Optional> measure_callbacks, // PyTaskSchedulerNode::FTune f_tune, // PyTaskSchedulerNode::FInitializeTask f_initialize_task, // PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, // diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 6eacd4d4f12a..ff3a14c076e4 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -20,7 +20,12 @@ #define TVM_META_SCHEDULE_TUNE_CONTEXT_H_ #include +#include +#include +#include +#include #include +#include #include #include #include @@ -28,6 +33,8 @@ namespace tvm { namespace meta_schedule { +class TaskSchedulerNode; + /*! \brief The auto tuning context. */ class TuneContextNode : public runtime::Object { public: @@ -41,19 +48,27 @@ class TuneContextNode : public runtime::Object { Optional search_strategy; /*! \brief The schedule rules. */ Array sch_rules; + /*! \brief The postprocessors. */ + Array postprocs; + /*! \brief The probability of using certain mutator. */ + Map mutator_probs; /*! \brief The name of the tuning task. */ - Optional task_name; + String task_name; /*! \brief The random state. */ support::LinearCongruentialEngine::TRandState rand_state; /*! \brief The number of threads to be used. */ int num_threads; + /*! \brief The task scheduler that owns the tune context */ + const TaskSchedulerNode* task_scheduler; /*! \brief Whether the tuning task has been stopped or finished. */ bool is_stopped; - /*! \brief Packed functions to fetch the runner results asynchronously. */ - Optional> runner_futures; /*! \brief The measure candidates. */ Optional> measure_candidates; + /*! \brief The building results. */ + Optional> builder_results; + /*! \brief Packed functions to fetch the runner results asynchronously. */ + Optional> runner_futures; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("mod", &mod); @@ -61,14 +76,20 @@ class TuneContextNode : public runtime::Object { v->Visit("space_generator", &space_generator); v->Visit("search_strategy", &search_strategy); v->Visit("sch_rules", &sch_rules); + v->Visit("postprocs", &postprocs); + v->Visit("mutator_probs", &mutator_probs); v->Visit("task_name", &task_name); v->Visit("rand_state", &rand_state); v->Visit("num_threads", &num_threads); v->Visit("is_stopped", &is_stopped); + v->Visit("builder_results", &builder_results); v->Visit("runner_futures", &runner_futures); v->Visit("measure_candidates", &measure_candidates); } + /*! \brief Initialize members that needs initialization with tune context. */ + void Initialize(); + static constexpr const char* _type_key = "meta_schedule.TuneContext"; TVM_DECLARE_FINAL_OBJECT_INFO(TuneContextNode, Object); }; @@ -86,6 +107,8 @@ class TuneContext : public runtime::ObjectRef { * \param space_generator The design space generator. * \param search_strategy The search strategy. * \param sch_rules The schedule rules. + * \param postprocs The postprocessors. + * \param mutator_probs The probability of using certain mutator. * \param task_name The name of the tuning task. * \param rand_state The random state. * \param num_threads The number of threads to be used. @@ -95,6 +118,8 @@ class TuneContext : public runtime::ObjectRef { Optional space_generator, // Optional search_strategy, // Optional> sch_rules, // + Optional> postprocs, // + Optional> mutator_probs, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads); diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index fcd2326050ed..89b1e9117ff4 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -29,6 +29,7 @@ #include #include // for uint64_t +#include namespace tvm { namespace support { @@ -73,6 +74,12 @@ class LinearCongruentialEngine { */ static constexpr result_type max() { return modulus - 1; } + /*! + * \brief Get a device random state + * \return The random state + */ + static TRandState DeviceRandom() { return (std::random_device()()) % modulus; } + /*! * \brief Operator to move the random state to the next and return the new random state. According * to definition of linear congruential engine, the new random state value is computed as @@ -93,6 +100,7 @@ class LinearCongruentialEngine { * \param rand_state The random state given in result_type. */ void Seed(TRandState rand_state = 1) { + ICHECK(rand_state != -1) << "The seed can't be -1 which should be changed to random seed!"; rand_state %= modulus; // Make sure the seed is within the range of modulus. if (rand_state == 0) rand_state = 1; // Avoid getting all 0 given the current parameter set. diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index e482a18c4a5b..ee03997d6729 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -187,6 +187,95 @@ class LinkedParam : public ObjectRef { TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode); }; +/*! \brief A mapping from multi-dimensional indices to another set of multi-dimensional indices */ +class IndexMapNode : public Object { + public: + /*! \brief The source indices */ + Array src_iters; + /*! \brief The target indices */ + Array tgt_iters; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("src_iters", &src_iters); + v->Visit("tgt_iters", &tgt_iters); + } + + /*! + * \brief Take `inputs` as the source indices and return the corresponding target indices. + * \param inputs The source indices. + * \return The target indices. + */ + Array Apply(const Array& inputs) const; + + /*! + * \brief Map a shape to the output space + * \param shape The shape in the source space + * \return The shape in the target space + */ + Array MapShape(const Array& shape) const; + + /*! + * \brief Convert to string representation in Python. + * \return The stringified lambda expression in Python. + */ + String ToPythonString() const; + + static constexpr const char* _type_key = "tir.IndexMap"; + TVM_DECLARE_FINAL_OBJECT_INFO(IndexMapNode, Object); +}; + +/*! + * \brief Managed reference to IndexMapNode. + * \sa IndexMapNode + */ +class IndexMap : public ObjectRef { + public: + /*! + * \brief Constructor. + * \param src_iters The source indices. + * \param tgt_iters The target indices. + */ + explicit IndexMap(Array src_iters, Array tgt_iters); + /*! + * \brief Create an index map from a packed function + * \param ndim The number of dimensions + * \param func The function to be applied + * \return The created index map + */ + static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc(Array)> func); + TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode); +}; + +/*! + * \brief Tensor TensorIntrin for Tensorization + */ +class TensorIntrinNode : public Object { + public: + /*! \brief The function to describe the computation. */ + PrimFunc description; + /*! \brief The intrinsic function for lower-level implement. */ + PrimFunc implementation; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("description", &description); + v->Visit("implementation", &implementation); + } + + static constexpr const char* _type_key = "tir.TensorIntrin"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); +}; + +class TensorIntrin : public ObjectRef { + public: + TVM_DLL explicit TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func); + + TVM_DLL static TensorIntrin Register(String name, PrimFunc desc_func, PrimFunc intrin_func); + + TVM_DLL static TensorIntrin Get(String name); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode) +}; + /*! * \brief Specialize parameters of PrimFunc. * \param func The PrimFunc to be specialized. diff --git a/include/tvm/tir/schedule/instruction.h b/include/tvm/tir/schedule/instruction.h index 5a9e687dc8c7..1af5ab07e67c 100644 --- a/include/tvm/tir/schedule/instruction.h +++ b/include/tvm/tir/schedule/instruction.h @@ -121,6 +121,9 @@ class InstructionKindNode : public runtime::Object { // not visited: f_attrs_from_json } + /*! \brief Checks if the instruction kind is EnterPostproc */ + bool IsPostproc() const; + static constexpr const char* _type_key = "tir.InstructionKind"; TVM_DECLARE_FINAL_OBJECT_INFO(InstructionKindNode, runtime::Object); }; diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 210ed53a7904..dc5e99faccb3 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -210,6 +210,14 @@ class ScheduleNode : public runtime::Object { */ virtual Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) = 0; + /*! + * \brief Sample a compute-at location of the given block + * \param block_rv The block whose compute-at location is to be sampled + * \param decision The sampling decision + * \return The sampled loop where the input block is to be computed at + */ + virtual LoopRV SampleComputeLocation(const BlockRV& block_rv, + Optional decision = NullOpt) = 0; /******** Schedule: Get blocks & loops ********/ /*! @@ -347,6 +355,11 @@ class ScheduleNode : public runtime::Object { */ virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) = 0; + /******** Schedule: Data movement ********/ + virtual BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) = 0; + virtual BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) = 0; /******** Schedule: Compute location ********/ /*! * \brief Move a producer block under the specific loop, and regenerate the @@ -465,17 +478,36 @@ class ScheduleNode : public runtime::Object { */ virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0; /******** Schedule: Blockize & Tensorize ********/ + /*! + * \brief Make subtree rooted by a specific loop into a block + * \param loop_rv The root of the subtree + * \return The new block + */ + virtual BlockRV Blockize(const LoopRV& loop_rv) = 0; + /*! + * \brief Tensorize the computation enclosed by loop with tensor_intrin + * \param loop_rv the loop/block to be tensorized + * \param intrin the tensor intrinsic + */ + virtual void Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) = 0; + /*! + * \brief Tensorize the computation enclosed by loop with tensor_intrin + * \param loop_rv The loop/block to be tensorized + * \param intrin_name Name of the tensor intrinsic + */ + virtual void Tensorize(const LoopRV& loop_rv, const String& intrin_name) = 0; + /******** Schedule: Annotation ********/ /*! * \brief Annotate a loop with a key value pair - * \param loop_rv The loop to be annotated + * \param loop The loop to be annotated * \param ann_key The annotation key * \param ann_val The annotation value, a string or a ExprRV */ virtual void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) = 0; /*! * \brief Annotate a block with a key value pair - * \param block_rv The block to be annotated + * \param loop The block to be annotated * \param ann_key The annotation key * \param ann_val The annotation value, a string or a ExprRV */ @@ -483,17 +515,32 @@ class ScheduleNode : public runtime::Object { const ObjectRef& ann_val) = 0; /*! * \brief Unannotate a loop's annotation with key ann_key - * \param loop_rv The loop to be unannotated + * \param loop The loop to be unannotated * \param ann_key The annotation key */ virtual void Unannotate(const LoopRV& loop_rv, const String& ann_key) = 0; /*! * \brief Unannotate a block's annotation with key ann_key - * \param block_rv The block to be unannotated + * \param loop The block to be unannotated * \param ann_key The annotation key */ virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0; + /******** Schedule: Layout transformation ********/ + /*! + * \brief Apply a transformation represented by IndexMap to buffer + * \details The indices and the access region to the target buffer is transformed by the given + * index_map. The index_map is used to infer the new shape of the buffer. Buffer must be either + * a function parameter, or allocated in a block (it cannot be a buffer subregion created via + * 'match_buffer'). + * \param block_rv The block that accesses the target buffer. + * \param buffer_index The index of the buffer in block's read or write region. + * \param is_write_index Whether the buffer_index is the index of the block's write region. + * \param index_map The transformation to apply. + */ + virtual void TransformLayout(const BlockRV& block_rv, int buffer_index, bool is_write_index, + const IndexMap& index_map) = 0; + /******** Schedule: Misc ********/ /*! \brief A no-op that marks the start of postprocessing phase of scheduling */ virtual void EnterPostproc() = 0; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 066496704e5f..d8726541aecd 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1224,7 +1224,7 @@ class BlockRealize : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode); }; -/*! \brief namespace of possible attribute sin AttrStmt.attr_key */ +/*! \brief namespace of possible attributes in AttrStmt.attr_key */ namespace attr { // The above attr does not pass to ir stage. /*! \brief Mark launching extent of thread, used by device API. */ @@ -1357,6 +1357,93 @@ constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_ */ constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint"; +/*! + * \brief Mark that the block need to add predicate for block var bounds during lowering + */ +constexpr const char* require_block_var_bound_predicate = "require_bound_predicate"; + +/*! + * \brief Mark that the loop should be further skip and bound to environment threads to enable + * cooperative fetching. + */ +constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch"; + +/*! + * \brief Mark that the block should be further rewritten using tensorization. + */ +constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize"; + +/*! \brief Mark that tensor core is enabled in the PrimExpr */ +constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled"; + +/*! \brief The allowed range of thread extent in thread bindings */ +constexpr const char* meta_schedule_thread_extent_low_inclusive = + "meta_schedule.thread_extent_low_inclusive"; + +/*! \brief The allowed range of thread extent in thread bindings */ +constexpr const char* meta_schedule_thread_extent_high_inclusive = + "meta_schedule.thread_extent_high_inclusive"; + +/*! + * \brief Mark a block as generated by cache_read or cache_write block. + * 0 means cache_read; 1 means cache_write. + * \sa meta_schedule_cache_type_read + * \sa meta_schedule_cache_type_write + */ +constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type"; + +/*! \sa meta_schedule_cache_type */ +constexpr const int meta_schedule_cache_type_read = 0; + +/*! \sa meta_schedule_cache_type */ +constexpr const int meta_schedule_cache_type_write = 1; + +/*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */ +constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure"; + +/*! \brief Mark the block whose producer needs to be applied by rule Random-Compute-Location */ +constexpr const char* meta_schedule_random_compute_producer = + "meta_schedule.random_compute_producer"; + +/*! \brief Mark auto-parallel setting on the block. */ +constexpr const char* meta_schedule_parallel = "meta_schedule.parallel"; + +/*! \brief Mark auto-vectorize setting on the block. */ +constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize"; + +/*! \brief Mark auto-unroll setting on the block. */ +constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit"; + +/*! \brief Mark auto-unroll setting on the block. */ +constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit"; + +/*! \brief Pragma: auto-unroll, max_step */ +constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step"; + +/*! \brief Pragma: unroll explicit */ +constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit"; + +/*! \brief Mark the scope of the software pipeline */ +constexpr const char* software_pipeline_scope = "software_pipeline_scope"; + +/*! \brief Mark the stage of a statement in the software pipeline */ +constexpr const char* software_pipeline_stage = "software_pipeline_stage"; + +/*! \brief Mark the order of a statement in the software pipeline */ +constexpr const char* software_pipeline_order = "software_pipeline_order"; + +/*! \brief Mark the stage of the result of the software pipeline lowering. This is used to specify + * the behavior of nested software pipelines. Should be a 3-tuple consisting of the stage of the + * prologue, the body, and the epilogue of the software pipeline. + */ +constexpr const char* nested_software_pipeline_stage = "nested_software_pipeline_stage"; + +/*! \brief Mark the stage of the result of the software pipeline lowering. This is used to specify + * the behavior of nested software pipelines. Should be a 3-tuple consisting of the stage of the + * prologue, the body, and the epilogue of the software pipeline. + */ +constexpr const char* nested_software_pipeline_order = "nested_software_pipeline_order"; + /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 0b4ace20078c..3ddebc5bf0f0 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -406,9 +406,12 @@ inline T Substitute(T input, const std::unordered_map& * \param stmt_or_expr The ir to be visited. * \param fvisit The visitor function to be applied. If fvisit returns false, it won't visit the * children of the node + * \param visit_init_block Whether or not to visit the init block + * children of the node */ TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr, - const std::function& fvisit); + const std::function& fvisit, + bool visit_init_block = true); } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 7a6cfa364447..edd75998d757 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -383,6 +383,20 @@ TVM_DLL Pass LowerInitBlock(); */ TVM_DLL Pass PlanAndUpdateBufferAllocationLocation(); +/*! + * \brief Narrow the extents of some loops by checking whether some constraints in the block iter + * bound predicates can be directly applied on the loops. + * \return The pass. + */ +TVM_DLL Pass ApplyBlockBoundPredicate(); + +/*! + * \brief Narrow the extents of some loops by checking whether some constraints in the block iter + * bound predicates can be directly applied on the loops. + * \return The pass. + */ +TVM_DLL Pass ApplyBlockBoundPredicate(); + /*! * \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the * corresponding iter_values in BlockRealize, for opaque blocks by removing all @@ -484,6 +498,24 @@ TVM_DLL Pass MergeDynamicSharedMemoryAllocations(); */ TVM_DLL Pass ConvertForLoopsToSerial(); +/*! + * \brief Transform annotated loops into pipelined one that ovarlaps producers and consumers. + * \return The IR transform pass. + */ +TVM_DLL Pass InjectSoftwarePipeline(); + +/*! + * \brief Automatically do memory optimizations for auto copy blocks + * \return The pass. + */ +TVM_DLL Pass LowerAutoCopy(); + +/*! + * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) + * \return The pass. + */ +TVM_DLL Pass RenormalizeSplitPattern(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 1ac58e18db3e..07d74b9b6fb9 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -255,7 +255,7 @@ class IterVarNode : public Object { IterVarType iter_type; /*! * \brief additional tag on the iteration variable, - * set this if this is binded already to a known thread tag. + * set this if this is bound already to a known thread tag. */ String thread_tag; /*! diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index f1156998bdac..0e9c4abebbe1 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -543,7 +543,8 @@ def print_best(self, log_file, print_mode="schedule"): code: str The best schedule code in python API or CUDA source code """ - inp, _ = load_best_record(log_file, self.workload_key) + inp, res = load_best_record(log_file, self.workload_key) + print("Best codes (ms):", [float(c) * 1000.0 for c in res.costs]) if inp is None: raise RuntimeError( "Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file) diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py index 885eb0d1d0f8..75702b0a21af 100644 --- a/python/tvm/auto_scheduler/workload_registry.py +++ b/python/tvm/auto_scheduler/workload_registry.py @@ -194,7 +194,10 @@ def workload_key_to_tensors(workload_key): assert callable(value) args = deserialize_args(workload[1:]) - return value(*args) + result = value(*args) + if isinstance(result, tuple): + result = list(result) + return result def serialize_workload_registry_entry(workload_key): diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 34823ebb1781..2bea0a5da6d9 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -231,6 +231,8 @@ def build( elif isinstance(inputs, PrimFunc): input_mod = lower(inputs, name=name) elif isinstance(inputs, tvm.IRModule): + if name is not None and name != "default_function": + warnings.warn("Specifying name with IRModule input is useless") input_mod = lower(inputs) elif not isinstance(inputs, (dict, container.Map)): raise ValueError( diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 8b6672ccc371..b5ad51329ad9 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -19,8 +19,20 @@ from . import database from . import builder from . import runner +from . import mutator +from . import postproc +from . import schedule_rule from . import space_generator from . import search_strategy from . import schedule_rule from . import integration +from . import feature_extractor +from . import cost_model +from .search_strategy import ( + EvolutionarySearchConfig, + MeasureCandidate, + ReplayFuncConfig, + ReplayTraceConfig, +) +from .tune import tune_te, tune_tir, tune_relay from .tune_context import TuneContext diff --git a/python/tvm/meta_schedule/builder/builder.py b/python/tvm/meta_schedule/builder/builder.py index ef5d3ca130a7..5d658f0fec23 100644 --- a/python/tvm/meta_schedule/builder/builder.py +++ b/python/tvm/meta_schedule/builder/builder.py @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. """Meta Schedule builders that translate IRModule to runtime.Module, and then export""" -from typing import Dict, List, Optional +from typing import List, Optional, Dict +from tvm.runtime import NDArray from tvm._ffi import register_object from tvm.ir import IRModule from tvm.runtime import NDArray, Object @@ -42,6 +43,7 @@ class BuilderInput(Object): mod: IRModule target: Target + params: Optional[Dict[str, NDArray]] def __init__( self, diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py index 954da87e6a63..863aac4d0ee3 100644 --- a/python/tvm/meta_schedule/builder/local_builder.py +++ b/python/tvm/meta_schedule/builder/local_builder.py @@ -22,13 +22,28 @@ from tvm._ffi import register_func from tvm.ir import IRModule -from tvm.runtime import Module, NDArray, load_param_dict, save_param_dict +from tvm.runtime import NDArray +from tvm.runtime import Module, load_param_dict, save_param_dict from tvm.target import Target from ...contrib.popen_pool import MapResult, PopenPoolExecutor, StatusKind from ..utils import cpu_count, get_global_func_with_default_on_worker from .builder import BuilderInput, BuilderResult, PyBuilder +logger = logging.getLogger(__name__) + + +def _serialize_params(params: Optional[Dict[str, NDArray]]) -> Optional[bytearray]: + if params is None: + return None + return save_param_dict(params) + + +def _deserialize_params(params: Optional[bytearray]) -> Optional[Dict[str, NDArray]]: + if params is None: + return None + return load_param_dict(params) + logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -127,9 +142,8 @@ def __init__( The initializer to be used for the worker processes. """ super().__init__() - if max_workers is None: - max_workers = cpu_count() + max_workers = cpu_count(logical=True) logger.info("LocalBuilder: max_workers = %d", max_workers) self.pool = PopenPoolExecutor( diff --git a/python/tvm/meta_schedule/cost_model/__init__.py b/python/tvm/meta_schedule/cost_model/__init__.py index 3d4a81e1222f..8fc6f04ac955 100644 --- a/python/tvm/meta_schedule/cost_model/__init__.py +++ b/python/tvm/meta_schedule/cost_model/__init__.py @@ -19,3 +19,4 @@ """ from .cost_model import CostModel, PyCostModel from .random_model import RandomModel +from .xgb_model import XGBModel diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py index f5bd60162ec5..13ca203c908b 100644 --- a/python/tvm/meta_schedule/cost_model/cost_model.py +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -15,17 +15,19 @@ # specific language governing permissions and limitations # under the License. """Meta Schedule CostModel.""" -import ctypes + from typing import List +import ctypes + +import numpy as np -import numpy as np # type: ignore from tvm._ffi import register_object from tvm.runtime import Object from .. import _ffi_api from ..runner import RunnerResult -from ..search_strategy import MeasureCandidate from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate from ..utils import _get_hex_address, check_override diff --git a/python/tvm/meta_schedule/cost_model/metric.py b/python/tvm/meta_schedule/cost_model/metric.py index efd8dc68ac0d..7eb6da6f07d9 100644 --- a/python/tvm/meta_schedule/cost_model/metric.py +++ b/python/tvm/meta_schedule/cost_model/metric.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. """Cost model metrics for meta schedule""" -import numpy as np # type: ignore +from typing import List +import numpy as np -def max_curve(trial_scores: np.ndarray) -> np.ndarray: +def max_curve(trial_scores: np.ndarray) -> List[float]: """f(n) = max([s[i] fo i < n]) Parameters @@ -28,8 +29,8 @@ def max_curve(trial_scores: np.ndarray) -> np.ndarray: Returns ------- - curve : np.ndarray - A vector, the max-curve function values + curve : List[float] + function values """ ret = np.empty(len(trial_scores)) keep = -1e9 diff --git a/python/tvm/meta_schedule/cost_model/random_model.py b/python/tvm/meta_schedule/cost_model/random_model.py index 23238d25797c..56c65f64afa3 100644 --- a/python/tvm/meta_schedule/cost_model/random_model.py +++ b/python/tvm/meta_schedule/cost_model/random_model.py @@ -17,14 +17,14 @@ """ Random cost model """ -from typing import List, Optional, Tuple, Union +from typing import List, Union, Tuple, Optional -import numpy as np # type: ignore +import numpy as np -from ..cost_model import PyCostModel from ..runner import RunnerResult -from ..search_strategy import MeasureCandidate from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate +from ..cost_model import PyCostModel class RandomModel(PyCostModel): @@ -70,7 +70,7 @@ def load(self, path: str) -> None: path : str The file path. """ - self.random_state = tuple(np.load(path, allow_pickle=True)) # type: ignore + self.random_state = tuple(np.load(path, allow_pickle=True)) def save(self, path: str) -> None: """Save the cost model to given file location. @@ -116,7 +116,7 @@ def predict(self, tune_context: TuneContext, candidates: List[MeasureCandidate]) The predicted running results. """ np.random.set_state(self.random_state) - # TODO(@zxybazh): Use numpy's RandState object: + # todo(@zxybazh): Use numpy's RandState object: # https://numpy.org/doc/1.16/reference/generated/numpy.random.RandomState.html#numpy.random.RandomState result = np.random.rand(len(candidates)) * self.max_range self.random_state = np.random.get_state() diff --git a/python/tvm/meta_schedule/cost_model/xgb_model.py b/python/tvm/meta_schedule/cost_model/xgb_model.py new file mode 100644 index 000000000000..6b833963f322 --- /dev/null +++ b/python/tvm/meta_schedule/cost_model/xgb_model.py @@ -0,0 +1,680 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +XGBoost-based cost model +""" +from itertools import chain as itertools_chain +import logging +import os +import tempfile +from typing import Any, Callable, Dict, List, NamedTuple, Optional, TYPE_CHECKING, Tuple + +import numpy as np + +from ...contrib.tar import tar, untar +from ..cost_model import PyCostModel +from ..feature_extractor import FeatureExtractor +from ..runner import RunnerResult +from ..search_strategy import MeasureCandidate +from ..utils import cpu_count +from .metric import max_curve + +if TYPE_CHECKING: + from ..tune_context import TuneContext + import xgboost as xgb + + +logger = logging.getLogger(__name__) + + +def make_metric_sorter(focused_metric): + """ Make sure the focused metric is the first one. """ + + def metric_name_for_sort(name): + if focused_metric == name: + return "!" + name + return name + + def sort_key(key): + key, _ = key + return metric_name_for_sort(key) + + return sort_key + + +class PackSum: + """The pack-sum format + + Parameters + ---------- + dmatrix : xgb.DMatrix + A float64 array of shape [n, m], + where `n` is the packed number of blocks, + and `m` is the length of feature vector on each block + ids : np.ndarray + An int64 array of shape [n] containing nonnegative integers, + indicating which the index of a sample that a block belongs to + """ + + dmatrix: "xgb.DMatrix" # type: ignore # pylint: disable=invalid-name + ids: np.ndarray + + def __init__( + self, + xs: List[np.ndarray], + ys: Optional[np.ndarray], + ): + """Create PackSum format given a batch of samples + + Parameters + ---------- + xs : List[np.ndarray] + A batch of input samples + ys : Optional[List[float]] + A batch of labels. None means no labels available. + """ + import xgboost as xgb # type: ignore # pylint: disable=import-outside-toplevel + + repeats = [x.shape[0] for x in xs] + xs = np.concatenate(xs, axis=0) + self.ids = np.concatenate([[i] * repeat for i, repeat in enumerate(repeats)], axis=0) + if ys is None: + self.dmatrix = xgb.DMatrix(data=xs, label=None) + else: + ys = np.concatenate([[y] * repeat for y, repeat in zip(ys, repeats)], axis=0) + self.dmatrix = xgb.DMatrix(data=xs, label=ys) + self.dmatrix.set_weight(ys) + + def predict_with_score(self, pred: np.ndarray) -> np.ndarray: + """Predict the labels given the block level prediction scores. + + Parameters + ---------- + pred : np.ndarray + The block level predictions + + Returns + ------- + result : np.ndarray + The predictions for each candidate. + """ + return np.bincount(self.ids, weights=pred) + + def obj_square_error(self, ys_pred: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Implement square error loss on pack-sum format as + a custom objective function for xgboost. + + Parameters + ---------- + ys_pred: np.ndarray + The predictions + + Returns + ------- + gradient: np.ndarray + The gradient according to the xgboost format + hessian: np.ndarray + The hessian according to the xgboost format + """ + # Making prediction + ys_pred = self.predict_with_score(ys_pred) + # Propagate prediction to each block + ys_pred = ys_pred[self.ids] + # The gradient and hessian + ys = self.dmatrix.get_label() # type: ignore # pylint: disable=invalid-name + gradient = ys_pred - ys + hessian = np.ones_like(gradient) + return gradient * ys, hessian * ys + + def rmse(self, ys_pred: np.ndarray) -> Tuple[str, float]: + """Evaluate RMSE (rooted mean square error) in the pack-sum format + + Parameters + ---------- + ys_pred: np.ndarray + The raw predictions + + Returns + ------- + name: str + The name of the metric + score: float + The score of the metric + """ + # Making prediction + ys_pred = self.predict_with_score(ys_pred) + # Propagate prediction to each block + ys_pred = ys_pred[self.ids] + # The RMSE + ys = self.dmatrix.get_label() # type: ignore # pylint: disable=invalid-name + square_error = np.square(ys_pred - ys) + rmse = np.sqrt(square_error.mean()) + return "p-rmse", rmse + + def average_peak_score( + self, + ys_pred: np.ndarray, + n: int, + ) -> Tuple[str, float]: + """Evaluate average-peak-score@N in the pack-sum format + + Parameters + ---------- + ys_pred: np.ndarray + The raw prediction + n : int + The N in average-peak-score@N + + Returns + ------- + name: str + The name of the metric + score: float + The score of the metric + """ + ys = self.dmatrix.get_label() # type: ignore # pylint: disable=invalid-name + ys = self.predict_with_score(ys) # type: ignore # pylint: disable=invalid-name + ys = ys / np.unique(self.ids, return_counts=True)[1] # type: ignore # pylint: disable=invalid-name + ys_pred = self.predict_with_score(ys_pred) + trials = np.argsort(ys_pred)[::-1][:n] + trial_scores = ys[trials] + curve = max_curve(trial_scores) / np.max(ys) + score = np.mean(curve) + return f"a-peak@{n}", score + + +class XGBConfig(NamedTuple): + """XGBoost model configuration + + Parameters + ---------- + max_depth : int + The maximum depth. + gamma : float + The gamma. + min_child_weight : float + The minimum child weight. + eta : float + The eta, learning rate. + seed : int + The random seed. + nthread : Optional[int], + The number of threads to use. + Default is None, which means to use physical number of cores. + """ + + def to_dict(self): + xgb_params = { + "max_depth": self.max_depth, + "gamma": self.gamma, + "min_child_weight": self.min_child_weight, + "eta": self.eta, + "seed": self.seed, + "nthread": self.nthread, + } + return xgb_params + + max_depth: int = 10 + gamma: float = 0.001 + min_child_weight: float = 0 + eta: float = 0.2 + seed: int = 43 + nthread: Optional[int] = None + + +class XGBModel(PyCostModel): + """XGBoost model + + Parameters + ---------- + extractor : FeatureExtractor + The feature extractor for the model. + config : XGBConfig + The XGBoost model config. + num_warmup_samples : int + The number of samples that are used for warmup, i.e., the first few samples are predicted + with random results. + early_stopping_rounds : int + The number of rounds for early stopping. + verbose_eval : int + The verbose level when doing evaluation. + average_peak_n : int + The number to calculate average peak score. + """ + + # feature extractor + extractor: FeatureExtractor + # xgboost model config + config: XGBConfig + # behavior of randomness + num_warmup_samples: int + # evaluation + early_stopping_rounds: int + verbose_eval: int + average_peak_n: int + # states + cached_features: List[np.ndarray] + cached_mean_costs: np.ndarray + cached_normalizer: Optional[float] + booster: Optional["xgb.Booster"] + + def __init__( + self, + *, + # feature extractor + extractor: FeatureExtractor, + # xgboost model config + config: XGBConfig = XGBConfig(), + # behavior of randomness + num_warmup_samples: int = 100, + # evaluation + early_stopping_rounds: int = 50, + verbose_eval: int = 25, + average_peak_n: int = 32, + ): + super().__init__() + # feature extractor + self.extractor = extractor + # model-related + if config.nthread is None: + # use physical core number + config = config._replace(nthread=cpu_count(logical=False)) + self.config = config + # behavior of randomness + self.num_warmup_samples = num_warmup_samples + # evaluation + self.early_stopping_rounds = early_stopping_rounds + self.verbose_eval = verbose_eval + self.average_peak_n = average_peak_n + # states + self.cached_features = [] + self.cached_mean_costs = np.empty((0,), dtype="float64") + self.cached_normalizer = None + self.booster = None + + def load(self, path: str) -> None: + """Load the cost model from given file location. + + Parameters + ---------- + path : str + The file path. + + Note + ---- + Since XGBoost model trains from scratch, each time we can only load the model without the + previous cached features / results so any call of update won't use previous training data. + """ + import xgboost as xgb # pylint: disable=import-outside-toplevel + + with tempfile.TemporaryDirectory() as tmp_dir: + untar(path, tmp_dir) + self.booster = xgb.Booster() + self.booster.load_model(os.path.join(tmp_dir, "model.bin")) + self.cached_features = list( + np.load(os.path.join(tmp_dir, "cached_features.npy"), allow_pickle=True) + ) + self.cached_mean_costs = np.load( + os.path.join(tmp_dir, "cached_mean_costs.npy"), allow_pickle=True + ) + self._set_cached_normalizer() + + def save(self, path: str) -> None: + """Save the cost model to given file location. + + Parameters + ---------- + path : str + The file path. + + Note + ---- + Since XGBoost model trains from scratch, each time we can only save the model without the + previous cached features / results so any call of update won't use previous training data. + """ + import xgboost as xgb # pylint: disable=import-outside-toplevel + + if self.booster is None: + # save all the parameters + self.booster = xgb.Booster(self.config.to_dict()) + with tempfile.TemporaryDirectory() as tmp_dir: + self.booster.save_model(os.path.join(tmp_dir, "model.bin")) + np.save( + os.path.join(tmp_dir, "cached_features.npy"), + np.array(self.cached_features, dtype=object), + ) + np.save(os.path.join(tmp_dir, "cached_mean_costs.npy"), self.cached_mean_costs) + tar( + path, + [ + os.path.join(tmp_dir, "model.bin"), + os.path.join(tmp_dir, "cached_features.npy"), + os.path.join(tmp_dir, "cached_mean_costs.npy"), + ], + ) + logger.info("Saved XGBModel to %s", path) + + def update( + self, + tune_context: "TuneContext", + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + """Update the cost model given running results. + + Parameters + ---------- + tune_context : TuneContext + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + results : List[RunnerResult] + The running results of the measure candidates. + """ + assert len(candidates) == len(results) + if len(candidates) == 0: + return + # extract feature and do validation + + def _mean_cost(x: RunnerResult) -> float: + if not x.run_secs: + return 1e10 + return float(np.median([float(s) for s in x.run_secs])) + + new_features = [ + x.numpy().astype("float32") + for x in self.extractor.extract_from(tune_context, candidates) + ] + new_mean_costs = np.asarray( + [_mean_cost(x) for x in results], + dtype="float32", + ) + if self.booster is not None and self.cached_normalizer is not None: + logger.debug( + "XGB validation: %s", + "\t".join( + f"{key}: {score:.6f}" + for key, score in self._validate( + xs=new_features, + ys=new_mean_costs, + ) + ), + ) + # use together with previous features + self.cached_features.extend(new_features) + self.cached_mean_costs = np.append(self.cached_mean_costs, new_mean_costs) + self._set_cached_normalizer() + # train xgb model + self._train( + xs=self.cached_features, + ys=self.cached_mean_costs, + ) + + def predict( + self, tune_context: "TuneContext", candidates: List[MeasureCandidate] + ) -> np.ndarray: + """Predict the normalized score using the cost model. + + Parameters + ---------- + tune_context : TuneContext, + The tuning context. + candidates : List[MeasureCandidate] + The measure candidates. + + Return + ------ + result : np.ndarray + The predicted normalized score. + """ + n_measured = len(self.cached_features) + if self.booster is not None and n_measured >= self.num_warmup_samples: + features = self.extractor.extract_from(tune_context, candidates) + ret = self._predict(xs=[x.numpy().astype("float32") for x in features]) + else: + ret = np.random.uniform( + low=0, + high=1, + size=(len(candidates),), + ) + return ret.astype("float64") + + def _train( # type: ignore # pylint: disable=invalid-name + self, + xs: List[np.ndarray], + ys: np.ndarray, + ) -> None: + import xgboost as xgb # type: ignore # pylint: disable=import-outside-toplevel + + self.d_train = PackSum( + xs=xs, + ys=self.cached_normalizer / ys, + ) + + def obj(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument + return self.d_train.obj_square_error(ys_pred) + + def rmse(ys_pred: np.ndarray, d_train: "xgb.DMatrix"): # type: ignore # pylint: disable = unused-argument + return self.d_train.rmse(ys_pred) + + def average_peak_score( + ys_pred: np.ndarray, d_train: "xgb.DMatrix" # type: ignore # pylint: disable = unused-argument + ): + return self.d_train.average_peak_score(ys_pred, self.average_peak_n) + + self.booster = xgb.train( + self.config.to_dict(), + self.d_train.dmatrix, + num_boost_round=10000, + obj=obj, + callbacks=[ + custom_callback( + early_stopping_rounds=self.early_stopping_rounds, + verbose_eval=self.verbose_eval, + fevals=[ + rmse, + average_peak_score, + ], + evals=[(self.d_train.dmatrix, "tr")], + ) + ], + ) + + del self.d_train + + def _predict( # type: ignore # pylint: disable=invalid-name + self, + xs: List[np.ndarray], + ) -> np.ndarray: + d_test = PackSum(xs=xs, ys=None) + pred = self.booster.predict(d_test.dmatrix) + ret = d_test.predict_with_score(pred) + return ret + + def _validate( # type: ignore # pylint: disable=invalid-name + self, + xs: List[np.ndarray], + ys: np.ndarray, + ) -> List[Tuple[str, float]]: + """Evaluate the score of inputs. + + Parameters + ---------- + xs : List[np.ndarray] + A batch of input samples + ys : List[float] + A batch of labels + + Returns + ------- + scores: np.ndarray + The predicted result for all inputs. + """ + if self.booster is None or self.cached_normalizer is None: + return [] + + d_valid = PackSum( + xs=xs, + ys=self.cached_normalizer / ys, + ) + + def average_peak_score(ys_pred: np.ndarray): + return d_valid.average_peak_score(ys_pred, n=self.average_peak_n) + + ys_pred = self.booster.predict(d_valid.dmatrix) + eval_result: List[Tuple[str, float]] = [ + feval(ys_pred) + for feval in ( + average_peak_score, + d_valid.rmse, + ) + ] + eval_result.sort(key=make_metric_sorter("p-rmse")) + return eval_result + + def _set_cached_normalizer(self) -> None: + filtered = self.cached_mean_costs[self.cached_mean_costs > 0] + if filtered.size == 0: + self.cached_normalizer = 1.0 + else: + self.cached_normalizer = np.min(filtered) + assert self.cached_normalizer > 0 + + +def custom_callback( + early_stopping_rounds: int, + verbose_eval: int, + fevals: List[Callable], + evals: List[Tuple["xgb.DMatrix", str]], + focused_metric: str = "tr-p-rmse", +): + """Callback function for xgboost to support multiple custom evaluation functions""" + sort_key = make_metric_sorter(focused_metric=focused_metric) + + state: Dict[str, Any] = {} + + def init(env: "xgb.core.CallbackEnv"): + """Internal function""" + booster: "xgb.Booster" = env.model + + state["best_iteration"] = 0 + state["best_score"] = float("inf") + if booster is None: + assert env.cvfolds is not None + return + if booster.attr("best_score") is not None: + state["best_score"] = float(booster.attr("best_score")) + state["best_iteration"] = int(booster.attr("best_iteration")) + state["best_msg"] = booster.attr("best_msg") + else: + booster.set_attr(best_iteration=str(state["best_iteration"])) + booster.set_attr(best_score=str(state["best_score"])) + + def callback(env: "xgb.core.CallbackEnv"): + # pylint:disable = import-outside-toplevel + import xgboost as xgb + from xgboost.callback import _fmt_metric + from xgboost.core import EarlyStopException + + try: + from xgboost.training import aggcv + except ImportError: + from xgboost.callback import _aggcv as aggcv + # pylint:enable = import-outside-toplevel + + if not state: + init(env) + booster: xgb.Booster = env.model + iteration: int = env.iteration + cvfolds: List[xgb.training.CVPack] = env.cvfolds + ##### Evaluation ##### + # `eval_result` is a list of (key, score) + eval_result: List[Tuple[str, float]] = [] + if cvfolds is None: + eval_result = list( + itertools_chain.from_iterable( + [ + (key, float(value)) + for key, value in map( + lambda x: x.split(":"), + booster.eval_set( + evals=evals, + iteration=iteration, + feval=feval, + ).split()[1:], + ) + ] + for feval in fevals + ) + ) + else: + eval_result = list( + itertools_chain.from_iterable( + [ + (key, score) + for key, score, _std in aggcv( + fold.eval( + iteration=iteration, + feval=feval, + ) + for fold in cvfolds + ) + ] + for feval in fevals + ) + ) + eval_result = list(eval_result) + eval_result.sort(key=sort_key) + + ##### Print eval result ##### + if verbose_eval and iteration % verbose_eval == 0: + info = [] + for key, score in eval_result: + if "null" not in key: + info.append(f"{key}: {score:.6f}") + logger.debug("XGB iter %3d: %s", iteration, "\t".join(info)) + + ##### Choose score and do early stopping ##### + score = None + for key, _score in eval_result: + if key == focused_metric: + score = _score + break + assert score is not None + + best_score = state["best_score"] + best_iteration = state["best_iteration"] + if score < best_score: + tab = "\t" # to work with f-string + msg = f"[{env.iteration}] {tab.join([_fmt_metric(x) for x in eval_result])}" + state["best_msg"] = msg + state["best_score"] = score + state["best_iteration"] = env.iteration + # save the property to attributes, so they will occur in checkpoint. + if env.model is not None: + env.model.set_attr( + best_score=str(state["best_score"]), + best_iteration=str(state["best_iteration"]), + best_msg=state["best_msg"], + ) + elif env.iteration - best_iteration >= early_stopping_rounds: + best_msg = state["best_msg"] + if verbose_eval and env.rank == 0: + logger.debug("XGB stopped. Best iteration: %s ", best_msg) + raise EarlyStopException(best_iteration) + + return callback diff --git a/python/tvm/meta_schedule/feature_extractor/__init__.py b/python/tvm/meta_schedule/feature_extractor/__init__.py index f29c44bd1efd..83ac7426cc42 100644 --- a/python/tvm/meta_schedule/feature_extractor/__init__.py +++ b/python/tvm/meta_schedule/feature_extractor/__init__.py @@ -20,4 +20,5 @@ measure candidates for use in cost model. """ from .feature_extractor import FeatureExtractor, PyFeatureExtractor +from .per_store_feature import PerStoreFeature from .random_feature_extractor import RandomFeatureExtractor diff --git a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py new file mode 100644 index 000000000000..30572ed5b935 --- /dev/null +++ b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""We extract one feature vector per BufferStoreNode statement in a TIR Stmt, +so we call this feature as "per-store" feature. +""" +from tvm._ffi import register_object + +from .. import _ffi_api +from .feature_extractor import FeatureExtractor + + +# /*! +# * \brief Create a feature extractor that extracts features from each BufferStore +# * \param buffers_per_store The number of buffers in each BufferStore; Pad or truncate if +# * necessary. +# * \param arith_intensity_curve_num_samples The number of samples used in the arithmetic intensity +# * curve. +# * \param cache_line_bytes The number of bytes in a cache line. +# * \return The feature extractor created. +# */ + + +@register_object("meta_schedule.PerStoreFeature") +class PerStoreFeature(FeatureExtractor): + """PerStoreFeature extracts one feature vector per BufferStoreNode + + Parameters + ---------- + buffers_per_store : int + The number of buffers in each BufferStore; Pad or truncate if necessary. + arith_intensity_curve_num_samples : int + The number of samples used in the arithmetic intensity curve. + cache_line_bytes : int + The number of bytes in a cache line. + """ + + buffers_per_store: int + """The number of buffers in each BufferStore; Pad or truncate if necessary.""" + arith_intensity_curve_num_samples: int # pylint: disable=invalid-name + """The number of samples used in the arithmetic intensity curve.""" + cache_line_bytes: int + """The number of bytes in a cache line.""" + feature_vector_length: int + """Length of the feature vector.""" + + def __init__( + self, + buffers_per_store: int = 5, + arith_intensity_curve_num_samples: int = 10, + cache_line_bytes: int = 64, + ): + self.__init_handle_by_constructor__( + _ffi_api.FeatureExtractorPerStoreFeature, # type: ignore # pylint: disable=no-member + buffers_per_store, + arith_intensity_curve_num_samples, + cache_line_bytes, + ) diff --git a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py index 7c72a25b2378..f9f2f287fd11 100644 --- a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py +++ b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py @@ -17,7 +17,7 @@ """Random Feature Extractor.""" from typing import List, Union, Tuple -import numpy as np # type: ignore +import numpy as np from tvm.runtime.ndarray import NDArray, array from ..tune_context import TuneContext diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index 47003c6faa25..89c5dc3c8c21 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -20,6 +20,7 @@ from tvm._ffi import register_object from tvm.ir import IRModule, transform +from tvm.meta_schedule.database.database import Database from tvm.relay import Any, Function as RelayFunc, vm from tvm.runtime import NDArray, Object from tvm.target import Target @@ -174,10 +175,16 @@ def __init__(self) -> None: @register_object("meta_schedule.ApplyHistoryBest") class ApplyHistoryBest(MetaScheduleContext): - pass + """An integration context that allows application of historically best records from a database""" + database: Database + """ The database to be queried from""" -def extract_task( + def __init__(self, database) -> None: + self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member + + +def extract_task_from_relay( mod: Union[IRModule, RelayFunc], target: Target, params: Optional[Dict[str, NDArray]] = None, diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/__init__.py new file mode 100644 index 000000000000..7981ea20aed9 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/__init__.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.mutator package. +Meta Schedule mutator that mutates the trace to explore the +design space. +""" +from .mutator import Mutator, PyMutator +from .mutate_parallel import MutateParallel +from .mutate_unroll import MutateUnroll +from .mutate_tile_size import MutateTileSize +from .mutate_compute_location import MutateComputeLocation diff --git a/python/tvm/meta_schedule/mutator/mutate_compute_location.py b/python/tvm/meta_schedule/mutator/mutate_compute_location.py new file mode 100644 index 000000000000..bb361247bf62 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutate_compute_location.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A mutator that mutates the compute-at location decision of SampleComputeLocation""" +from tvm._ffi.registry import register_object + +from .. import _ffi_api +from .mutator import Mutator + + +@register_object("meta_schedule.MutateComputeLocation") +class MutateComputeLocation(Mutator): + """A mutator that mutates the compute-at location decision of SampleComputeLocation""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.MutatorMutateComputeLocation, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/mutator/mutate_parallel.py b/python/tvm/meta_schedule/mutator/mutate_parallel.py new file mode 100644 index 000000000000..c66dddb825f4 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutate_parallel.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Mutator that mutates the parallel extent""" +from tvm._ffi.registry import register_object + +from .. import _ffi_api +from .mutator import Mutator + + +@register_object("meta_schedule.MutateParallel") +class MutateParallel(Mutator): + """Mutator that mutates the parallel extent""" + + def __init__(self, max_jobs_per_core: int) -> None: + """Mutator that mutates the parallel extent""" + self.__init_handle_by_constructor__( + _ffi_api.MutatorMutateParallel, # type: ignore # pylint: disable=no-member + max_jobs_per_core, + ) diff --git a/python/tvm/meta_schedule/mutator/mutate_tile_size.py b/python/tvm/meta_schedule/mutator/mutate_tile_size.py new file mode 100644 index 000000000000..9c94d4436143 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutate_tile_size.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Mutator that mutates the tile size""" +from tvm._ffi.registry import register_object + +from .. import _ffi_api +from .mutator import Mutator + + +@register_object("meta_schedule.MutateTileSize") +class MutateTileSize(Mutator): + """Mutator that mutates the tile size""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.MutatorMutateTileSize, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/mutator/mutate_unroll.py b/python/tvm/meta_schedule/mutator/mutate_unroll.py new file mode 100644 index 000000000000..f81953d008d4 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutate_unroll.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Mutator that mutates auto unroll step""" +from tvm._ffi.registry import register_object + +from .. import _ffi_api +from .mutator import Mutator + + +@register_object("meta_schedule.MutateUnroll") +class MutateUnroll(Mutator): + """Mutator that mutates auto unroll step""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.MutatorMutateUnroll, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py new file mode 100644 index 000000000000..d3b008591168 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutator.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule Mutator.""" +from typing import Optional, TYPE_CHECKING + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Trace + +from .. import _ffi_api +from ..utils import _get_hex_address, check_override + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +class Mutator(Object): + """Mutator is designed to mutate the trace to explore the design space.""" + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + """Initialize the mutator with a tune context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initializing the mutator. + """ + _ffi_api.MutatorInitializeWithTuneContext( # type: ignore # pylint: disable=no-member + self, tune_context + ) + + def apply(self, trace: Trace) -> Optional[Trace]: + """Apply the mutator function to the given trace. + + Parameters + ---------- + trace : Trace + The given trace for mutation. + + Returns + ------- + trace : Optional[Trace] + None if mutator failed, otherwise return the mutated trace. + """ + return _ffi_api.MutatorApply(self, trace, -1) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.PyMutator") +class PyMutator(Mutator): + """An abstract mutator with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, Mutator) + def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: + self.initialize_with_tune_context(tune_context) + + @check_override(self.__class__, Mutator) + def f_apply(trace: Trace, _) -> Optional[Trace]: + return self.apply(trace) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.MutatorPyMutator, # type: ignore # pylint: disable=no-member + f_initialize_with_tune_context, + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"{self.__class__.__name__}({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py new file mode 100644 index 000000000000..50fbb0e0852b --- /dev/null +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The tvm.meta_schedule.postproc package.""" +from .postproc import Postproc, PyPostproc +from .disallow_dynamic_loop import DisallowDynamicLoop +from .rewrite_cooperative_fetch import RewriteCooperativeFetch +from .rewrite_parallel_vectorize_unroll import RewriteParallelVectorizeUnroll +from .rewrite_reduction_block import RewriteReductionBlock +from .rewrite_unbound_block import RewriteUnboundBlock +from .rewrite_tensor_core import RewriteTensorCore +from .verify_gpu_code import VerifyGPUCode diff --git a/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py b/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py new file mode 100644 index 000000000000..5515d288e0e7 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that checks if the IRModule has any loop with non-constant extent""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.DisallowDynamicLoop") +class DisallowDynamicLoop(Postproc): + """A postprocessor that checks if the IRModule has any loop with non-constant extent""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocDisallowDynamicLoop, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py new file mode 100644 index 000000000000..8e3b332c77c4 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule Postproc.""" + +from typing import TYPE_CHECKING + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Schedule + +from .. import _ffi_api +from ..utils import _get_hex_address, check_override + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_object("meta_schedule.Postproc") +class Postproc(Object): + """Rules to apply a postprocessor to a schedule.""" + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + """Initialize the postprocessor with a tune context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initializing the postprocessor. + """ + _ffi_api.PostprocInitializeWithTuneContext( # type: ignore # pylint: disable=no-member + self, tune_context + ) + + def apply(self, sch: Schedule) -> bool: + """Apply a postprocessor to the given schedule. + + Parameters + ---------- + sch : Schedule + The schedule to be post processed. + + Returns + ------- + result : bool + Whether the postprocessor was successfully applied. + """ + return _ffi_api.PostprocApply(self, sch) # type: ignore # pylint: disable=no-member + + +@register_object("meta_schedule.PyPostproc") +class PyPostproc(Postproc): + """An abstract Postproc with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, Postproc) + def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: + self.initialize_with_tune_context(tune_context) + + @check_override(self.__class__, Postproc) + def f_apply(sch: Schedule) -> bool: + return self.apply(sch) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.PostprocPyPostproc, # type: ignore # pylint: disable=no-member + f_initialize_with_tune_context, + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"{self.__class__.__name__}({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py b/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py new file mode 100644 index 000000000000..e2d7c2212382 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that rewrites the cooperative fetch annotation to actual +vectorized cooperative fetching in loop bindings.""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteCooperativeFetch") +class RewriteCooperativeFetch(Postproc): + """A postprocessor that rewrites the cooperative fetch annotation to actual vectorized + cooperative fetching in loop bindings. + """ + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteCooperativeFetch, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py new file mode 100644 index 000000000000..abe7288acba9 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that applies parallelization, vectorization and auto unrolling +according to the annotation of each block""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteParallelVectorizeUnroll") +class RewriteParallelVectorizeUnroll(Postproc): + """A postprocessor that applies parallelization, vectorization and auto unrolling + according to the annotation of each block""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteParallelVectorizeUnroll, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py b/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py new file mode 100644 index 000000000000..7e15ed493ccb --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that rewrites reduction block by moving the init block out.""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteReductionBlock") +class RewriteReductionBlock(Postproc): + """A postprocessor that rewrites reduction block by moving the init block out.""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteReductionBlock, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/postproc/rewrite_tensor_core.py b/python/tvm/meta_schedule/postproc/rewrite_tensor_core.py new file mode 100644 index 000000000000..f858fed3a6d4 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_tensor_core.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that tensorize Tensor Core related components.""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteTensorCore") +class RewriteTensorCore(Postproc): + """A postprocessor that tensorize Tensor Core related components.""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteTensorCore, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py new file mode 100644 index 000000000000..f4113e5173c9 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that adds thread binding to unbound blocks""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteUnboundBlock") +class RewriteUnboundBlock(Postproc): + """A postprocessor that adds thread binding to unbound blocks""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteUnboundBlock, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/postproc/verify_gpu_code.py b/python/tvm/meta_schedule/postproc/verify_gpu_code.py new file mode 100644 index 000000000000..501e4423196c --- /dev/null +++ b/python/tvm/meta_schedule/postproc/verify_gpu_code.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that verifies if the GPU code is correct""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.VerifyGPUCode") +class VerifyGPUCode(Postproc): + """A postprocessor that verifies if the GPU code is correct""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocVerifyGPUCode, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/runner/local_runner.py b/python/tvm/meta_schedule/runner/local_runner.py index b1a9c678c6fc..6af403905cb4 100644 --- a/python/tvm/meta_schedule/runner/local_runner.py +++ b/python/tvm/meta_schedule/runner/local_runner.py @@ -33,7 +33,7 @@ run_evaluator_common, ) -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = logging.getLogger(__name__) class LocalRunnerFuture(RunnerFuture): diff --git a/python/tvm/meta_schedule/runner/rpc_runner.py b/python/tvm/meta_schedule/runner/rpc_runner.py index a6fb169fa590..6085de809767 100644 --- a/python/tvm/meta_schedule/runner/rpc_runner.py +++ b/python/tvm/meta_schedule/runner/rpc_runner.py @@ -38,6 +38,8 @@ run_evaluator_common, ) +logger = logging.getLogger(__name__) + logger = logging.getLogger(__name__) # pylint: disable=invalid-name diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py index b90780d5bfdb..f03c6de3df4b 100644 --- a/python/tvm/meta_schedule/schedule_rule/__init__.py +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -1,3 +1,6 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance @@ -16,4 +19,10 @@ Meta Schedule schedule rules are used for modification of blocks in a schedule. See also PostOrderApply. """ +from .add_rfactor import AddRFactor +from .auto_inline import AutoInline +from .cross_thread_reduction import CrossThreadReduction +from .multi_level_tiling import MultiLevelTiling, ReuseType +from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll +from .random_compute_location import RandomComputeLocation from .schedule_rule import PyScheduleRule, ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/add_rfactor.py b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py new file mode 100644 index 000000000000..72f9fc92f96e --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Add-rfactor Rule that add-rfactor to some blocks if needed""" +from typing import Optional + +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +@register_object("meta_schedule.AddRFactor") +class AddRFactor(ScheduleRule): + """Rules for add-rfactor to some blocks if needed. + + Parameters + ---------- + max_jobs_per_core: int + The maximum number of jobs to be launched per CPU core. It sets the uplimit of CPU + parallelism, i.e. `num_cores * max_jobs_per_core`. + Use -1 to disable parallelism. + max_innermost_factor: Optional[int] = None + The maximum size of the innermost factor. None means no limit. + """ + + def __init__( + self, + max_jobs_per_core: int = 16, + max_innermost_factor: Optional[int] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleAddRFactor, # type: ignore # pylint: disable=no-member + max_jobs_per_core, + max_innermost_factor, + ) diff --git a/python/tvm/meta_schedule/schedule_rule/auto_inline.py b/python/tvm/meta_schedule/schedule_rule/auto_inline.py new file mode 100644 index 000000000000..83828586bfb2 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/auto_inline.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Auto-Inline. Rule that inlines spatial blocks if it satisfies some conditions""" +from typing import List, Optional + +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +@register_object("meta_schedule.AutoInline") +class AutoInline(ScheduleRule): + """Rule that inlines spatial blocks if it satisfies some conditions + + Parameters + ---------- + into_producer : bool + If allows to inline a block into its producer + into_consumer : bool + If allows to inline a block into its consumer + into_cache_only : bool + If it only allows to inline into a block generated by cache_read/write + inline_const_tensor : bool + Always inline constant tensors + disallow_if_then_else : bool + Always disallow if-then-else-like constructs + require_injective : bool + Always require the read-to-write mapping to be ordered + require_ordered : bool + Always require the read-to-write mapping to be injective + disallow_op : Optional[List[str]] + The operators that are disallowed in auto inline + """ + + def __init__( + self, + into_producer: bool, + into_consumer: bool, + into_cache_only: bool, + inline_const_tensor: bool, + disallow_if_then_else: bool, + require_injective: bool, + require_ordered: bool, + disallow_op: Optional[List[str]] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleAutoInline, # type: ignore # pylint: disable=no-member + into_producer, + into_consumer, + into_cache_only, + inline_const_tensor, + disallow_if_then_else, + require_injective, + require_ordered, + disallow_op, + ) diff --git a/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py b/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py new file mode 100644 index 000000000000..f242e42aea4b --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Rules which apply cross-thread reduction to some reduction blocks correspondingly when needed""" +from typing import List + +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +@register_object("meta_schedule.CrossThreadReduction") +class CrossThreadReduction(ScheduleRule): + """A schedule rule which applies cross-thread reduction to some reduction blocks + correspondingly when needed + + Parameters + ---------- + thread_extents: List[int] + Candidates of thread axis extent (values are required to be positive). + """ + + def __init__(self, thread_extents: List[int]) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleCrossThreadReduction, # type: ignore # pylint: disable=no-member + thread_extents, + ) diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py new file mode 100644 index 000000000000..9e030d8a425c --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Multi-level tiling with reuse.""" +from typing import Any, Dict, List, Literal, NamedTuple, Optional + +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +class ReuseType(NamedTuple): + """Reuse type.""" + + req: Literal["no", "may", "must"] + levels: List[int] + scope: str + + def as_dict(self) -> Dict[str, Any]: + """Return the dict representation of the reuse type.""" + return { + "req": self.req, + "levels": self.levels, + "scope": self.scope, + } + + +@register_object("meta_schedule.MultiLevelTiling") +class MultiLevelTiling(ScheduleRule): + """Multi-level tiling with reuse. + + Parameters + ---------- + structure : str + The tiling structure. Recommended: + - 'SSRSRS' on CPU + - 'SSSRRSRS' on GPU + tile_bind : Optional[List[str]] + For each level of tiles, which thread axis it is bound to. Recommended: + - None on CPU + - [blockIdx.x, vthread.x, threadIdx.x] on GPU + use_tensor_core : bool + Whether to apply tensor core wmma intrinsic for the computation + max_innermost_factor : Optional[int] + The maximum size of the innermost factor. None means no limit + vector_load_lens : Optional[List[int]] + The length of vector lane in vectorized cooperative fetching. + None means disable vectorization + reuse_read : Optional[ReuseType] + Data reuse configuration for reading. None means no reuse. + reuse_write : Optional[ReuseType] + Data reuse configuration for writing. None means no reuse. + """ + + def __init__( + self, + structure: str, + tile_binds: Optional[List[str]] = None, + use_tensor_core: bool = False, + max_innermost_factor: Optional[int] = None, + vector_load_lens: Optional[List[int]] = None, + reuse_read: Optional[ReuseType] = None, + reuse_write: Optional[ReuseType] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleMultiLevelTiling, # type: ignore # pylint: disable=no-member + structure, + tile_binds, + use_tensor_core, + max_innermost_factor, + vector_load_lens, + reuse_read.as_dict() if reuse_read is not None else None, + reuse_write.as_dict() if reuse_write is not None else None, + ) diff --git a/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py b/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py new file mode 100644 index 000000000000..36513022a923 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Rule that mark parallelize, vectorize and unroll to each block correspondingly""" +from typing import List, Optional + +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +@register_object("meta_schedule.ParallelizeVectorizeUnroll") +class ParallelizeVectorizeUnroll(ScheduleRule): + """Rule that mark parallelize, vectorize and unroll to each block correspondingly + + Parameters + ---------- + max_jobs_per_core: int + The maximum number of jobs to be launched per CPU core. It sets the uplimit of CPU + parallelism, i.e. `num_cores * max_jobs_per_core`. + Use -1 to disable parallelism. + max_vectorize_extent: int + The maximum extent to be vectorized. It sets the uplimit of the CPU vectorization. + Use -1 to disable vectorization. + unroll_max_steps: Optional[List[int]] + The maximum number of unroll steps to be done. + Use None to disable unroll + unroll_explicit: bool + Whether to explicitly unroll the loop, or just add a unroll pragma + """ + + def __init__( + self, + max_jobs_per_core: int = 16, + max_vectorize_extent: int = 16, + unroll_max_steps: Optional[List[int]] = None, + unroll_explicit: bool = True, + ) -> None: + if unroll_max_steps is None: + unroll_max_steps = [] + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleParallelizeVectorizeUnroll, # type: ignore # pylint: disable=no-member + max_jobs_per_core, + max_vectorize_extent, + unroll_max_steps, + unroll_explicit, + ) diff --git a/python/tvm/meta_schedule/schedule_rule/random_compute_location.py b/python/tvm/meta_schedule/schedule_rule/random_compute_location.py new file mode 100644 index 000000000000..2355b0bfa8e5 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/random_compute_location.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Rule that randomly select a compute-at location for a free block""" +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +@register_object("meta_schedule.RandomComputeLocation") +class RandomComputeLocation(ScheduleRule): + """A rule that randomly select a compute-at location for a free block""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleRandomComputeLocation, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py index 298cdae4283a..174672235b42 100644 --- a/python/tvm/meta_schedule/search_strategy/__init__.py +++ b/python/tvm/meta_schedule/search_strategy/__init__.py @@ -19,5 +19,8 @@ Meta Schedule search strategy utilizes the design spaces given to generate measure candidates. """ -from .search_strategy import MeasureCandidate, PySearchStrategy, SearchStrategy -from .replay_trace import ReplayTrace + +from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate +from .replay_trace import ReplayTrace, ReplayTraceConfig +from .replay_func import ReplayFunc, ReplayFuncConfig +from .evolutionary_search import EvolutionarySearch, EvolutionarySearchConfig diff --git a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py new file mode 100644 index 000000000000..bfc5df52b1c8 --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Evolutionary Search Strategy""" + +from typing import NamedTuple + +from tvm._ffi import register_object + +from .. import _ffi_api +from .search_strategy import SearchStrategy + + +@register_object("meta_schedule.EvolutionarySearch") +class EvolutionarySearch(SearchStrategy): + """ + Replay Trace Search Strategy is a search strategy that always replays the trace by removing its + decisions so that the decisions would be randomly re-generated. + + Parameters + ---------- + num_trials_per_iter : int + Number of trials per iteration. + num_trials_total : int + Total number of trials. + population_size : int + The initial population of traces from measured samples and randomly generated samples. + init_measured_ratio : int + The ratio of measured samples in the initial population. + init_min_unmeasured : int + The minimal size of unmeasured population in the initial sampling. + genetic_num_iters : int + The number of iterations for genetic algorithm. + genetic_mutate_prob : float + The probability of mutation. + genetic_max_fail_count : int + The maximum number to retry mutation. + eps_greedy : float + The ratio of greedy selected samples in the final picks. + """ + + num_trials_per_iter: int + num_trials_total: int + population_size: int + init_measured_ratio: int + init_min_unmeasured: int + genetic_num_iters: int + genetic_mutate_prob: float + genetic_max_fail_count: int + eps_greedy: float + + def __init__( + self, + *, + num_trials_per_iter: int, + num_trials_total: int, + population_size: int, + init_measured_ratio: float, + init_min_unmeasured: int, + genetic_num_iters: int, + genetic_mutate_prob: float, + genetic_max_fail_count: int, + eps_greedy: float, + ) -> None: + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.SearchStrategyEvolutionarySearch, # type: ignore # pylint: disable=no-member + num_trials_per_iter, + num_trials_total, + population_size, + init_measured_ratio, + init_min_unmeasured, + genetic_num_iters, + genetic_mutate_prob, + genetic_max_fail_count, + eps_greedy, + ) + + +class EvolutionarySearchConfig(NamedTuple): + """Configuration for EvolutionarySearch""" + + num_trials_per_iter: int + num_trials_total: int + population_size: int = 2048 + init_measured_ratio: float = 0.2 + init_min_unmeasured: int = 50 + genetic_num_iters: int = 4 + genetic_mutate_prob: float = 0.85 + genetic_max_fail_count: int = 10 + eps_greedy: float = 0.05 + + def create_strategy(self) -> EvolutionarySearch: + return EvolutionarySearch( + num_trials_per_iter=self.num_trials_per_iter, + num_trials_total=self.num_trials_total, + population_size=self.population_size, + init_measured_ratio=self.init_measured_ratio, + init_min_unmeasured=self.init_min_unmeasured, + genetic_num_iters=self.genetic_num_iters, + genetic_mutate_prob=self.genetic_mutate_prob, + genetic_max_fail_count=self.genetic_max_fail_count, + eps_greedy=self.eps_greedy, + ) diff --git a/python/tvm/meta_schedule/search_strategy/replay_func.py b/python/tvm/meta_schedule/search_strategy/replay_func.py new file mode 100644 index 000000000000..34eadc7a3f6b --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/replay_func.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Replay Trace Search Strategy""" +from typing import NamedTuple + +from tvm._ffi import register_object + +from .. import _ffi_api +from .search_strategy import SearchStrategy + + +@register_object("meta_schedule.ReplayFunc") +class ReplayFunc(SearchStrategy): + """ + Replay Func Search Strategy is a search strategy that generates measure candidates by + calling a design space generator and transform the design space. + + Parameters + ---------- + num_trials_per_iter : int + Number of trials per iteration. + num_trials_total : int + Total number of trials. + """ + + num_trials_per_iter: int + num_trials_total: int + + def __init__( + self, + num_trials_per_iter: int, + num_trials_total: int, + ): + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.SearchStrategyReplayFunc, # pylint: disable=no-member + num_trials_per_iter, + num_trials_total, + ) + + +class ReplayFuncConfig(NamedTuple): + """Configuration for ReplayFunc""" + + num_trials_per_iter: int + num_trials_total: int + + def create_strategy(self) -> ReplayFunc: + return ReplayFunc(self.num_trials_per_iter, self.num_trials_total) diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py index 15f8295f2524..f55013546021 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_trace.py +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Replay Trace Search Strategy""" +from typing import NamedTuple from tvm._ffi import register_object from .search_strategy import SearchStrategy @@ -41,7 +42,17 @@ class ReplayTrace(SearchStrategy): def __init__(self, num_trials_per_iter: int, num_trials_total: int): """Constructor""" self.__init_handle_by_constructor__( - _ffi_api.ReplayTrace, # type: ignore # pylint: disable=no-member + _ffi_api.SearchStrategyReplayTrace, # pylint: disable=no-member num_trials_per_iter, num_trials_total, ) + + +class ReplayTraceConfig(NamedTuple): + """Configuration for ReplayTrace""" + + num_trials_per_iter: int + num_trials_total: int + + def create_strategy(self) -> ReplayTrace: + return ReplayTrace(self.num_trials_per_iter, self.num_trials_total) diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index 6cee09edd4fc..0c85d809796d 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -48,7 +48,11 @@ class MeasureCandidate(Object): sch: Schedule args_info: List[ArgInfo] - def __init__(self, sch: Schedule, args_info: List[ArgInfo]) -> None: + def __init__( + self, + sch: Schedule, + args_info: List[ArgInfo], + ) -> None: """Constructor. Parameters @@ -72,10 +76,7 @@ class SearchStrategy(Object): before usage and post-tuned after usage. """ - def initialize_with_tune_context( - self, - tune_context: "TuneContext", - ) -> None: + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: """Initialize the search strategy with tuning context. Parameters @@ -111,15 +112,29 @@ def generate_measure_candidates(self) -> Optional[List[MeasureCandidate]]: """ return _ffi_api.SearchStrategyGenerateMeasureCandidates(self) # type: ignore # pylint: disable=no-member - def notify_runner_results(self, results: List[RunnerResult]) -> None: + def notify_runner_results( + self, + tune_context: "TuneContext", + measure_candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: """Update the search strategy with profiling results. Parameters ---------- + tune_context : TuneContext + The tuning context for update. + measure_candidates : List[MeasureCandidate] + The measure candidates for update. results : List[RunnerResult] The profiling results from the runner. """ - _ffi_api.SearchStrategyNotifyRunnerResults(self, results) # type: ignore # pylint: disable=no-member + _ffi_api.SearchStrategyNotifyRunnerResults( # type: ignore # pylint: disable=no-member + self, + tune_context, + measure_candidates, + results, + ) @register_object("meta_schedule.PySearchStrategy") @@ -146,8 +161,12 @@ def f_generate_measure_candidates() -> List[MeasureCandidate]: return self.generate_measure_candidates() @check_override(self.__class__, SearchStrategy) - def f_notify_runner_results(results: List["RunnerResult"]) -> None: - self.notify_runner_results(results) + def f_notify_runner_results( + tune_context: "TuneContext", + measure_candidates: List[MeasureCandidate], + results: List["RunnerResult"], + ) -> None: + self.notify_runner_results(tune_context, measure_candidates, results) self.__init_handle_by_constructor__( _ffi_api.SearchStrategyPySearchStrategy, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py index 80f372a448f5..a9b2d560314a 100644 --- a/python/tvm/meta_schedule/space_generator/post_order_apply.py +++ b/python/tvm/meta_schedule/space_generator/post_order_apply.py @@ -32,5 +32,5 @@ class PostOrderApply(SpaceGenerator): def __init__(self): """Constructor""" self.__init_handle_by_constructor__( - _ffi_api.SpaceGeneratorPostOrderApply, # type: ignore # pylint: disable=no-member + _ffi_api.SpaceGeneratorPostOrderApply, # pylint: disable=no-member ) diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py index 391011b4f53f..a63d9a3f2183 100644 --- a/python/tvm/meta_schedule/task_scheduler/round_robin.py +++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py @@ -16,13 +16,15 @@ # under the License. """Round Robin Task Scheduler""" -from typing import List, TYPE_CHECKING +from typing import List, Optional, TYPE_CHECKING from tvm._ffi import register_object +from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback from ..builder import Builder from ..runner import Runner from ..database import Database +from ..cost_model import CostModel from .task_scheduler import TaskScheduler from .. import _ffi_api @@ -33,7 +35,21 @@ @register_object("meta_schedule.RoundRobin") class RoundRobin(TaskScheduler): - """Round Robin Task Scheduler""" + """Round Robin Task Scheduler + + Parameters + ---------- + tasks: List[TuneContext] + The list of tune context to process. + builder: Builder + The builder of the scheduler. + runner: Runner + The runner of the scheduler. + database: Database + The database of the scheduler. + measure_callbacks: Optional[List[MeasureCallback]] = None + The list of measure callbacks of the scheduler. + """ def __init__( self, @@ -41,6 +57,8 @@ def __init__( builder: Builder, runner: Runner, database: Database, + cost_model: Optional[CostModel] = None, + measure_callbacks: Optional[List[MeasureCallback]] = None, ) -> None: """Constructor. @@ -54,6 +72,8 @@ def __init__( The runner. database : Database The database. + measure_callbacks: Optional[List[MeasureCallback]] + The list of measure callbacks of the scheduler. """ self.__init_handle_by_constructor__( _ffi_api.TaskSchedulerRoundRobin, # type: ignore # pylint: disable=no-member @@ -61,4 +81,6 @@ def __init__( builder, runner, database, + cost_model, + measure_callbacks, ) diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index aeea154cfe02..dd8e3fe89b63 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -16,14 +16,16 @@ # under the License. """Auto-tuning Task Scheduler""" -from typing import List +from typing import List, Optional from tvm._ffi import register_object +from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback from tvm.runtime import Object from ..runner import Runner from ..builder import Builder from ..database import Database +from ..cost_model import CostModel from ..tune_context import TuneContext from .. import _ffi_api from ..utils import check_override @@ -43,12 +45,16 @@ class TaskScheduler(Object): The runner of the scheduler. database: Database The database of the scheduler. + measure_callbacks: List[MeasureCallback] = None + The list of measure callbacks of the scheduler. """ tasks: List[TuneContext] builder: Builder runner: Runner database: Database + cost_model: Optional[CostModel] + measure_callbacks: List[MeasureCallback] def tune(self) -> None: """Auto-tuning.""" @@ -59,7 +65,7 @@ def next_task_id(self) -> int: Returns ------- - int + next_task_id : int The next task id. """ return _ffi_api.TaskSchedulerNextTaskId(self) # type: ignore # pylint: disable=no-member @@ -94,7 +100,7 @@ def _is_task_running(self, task_id: int) -> bool: Returns ------- - bool + running : bool Whether the task is running. """ return _ffi_api.TaskSchedulerIsTaskRunning(self, task_id) # type: ignore # pylint: disable=no-member @@ -120,6 +126,8 @@ def __init__( builder: Builder, runner: Runner, database: Database, + cost_model: Optional[CostModel] = None, + measure_callbacks: Optional[List[MeasureCallback]] = None, ): """Constructor. @@ -133,6 +141,10 @@ def __init__( The runner of the scheduler. database: Database The database of the scheduler. + cost_model: Optional[CostModel] + The cost model of the scheduler. + measure_callbacks: List[MeasureCallback] + The list of measure callbacks of the scheduler. """ @check_override(self.__class__, TaskScheduler, required=False) @@ -173,6 +185,8 @@ def f_join_running_task(task_id: int) -> None: builder, runner, database, + cost_model, + measure_callbacks, f_tune, f_initialize_task, f_set_task_stopped, diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py index 7e516a510f66..b64891a3858d 100644 --- a/python/tvm/meta_schedule/testing/__init__.py +++ b/python/tvm/meta_schedule/testing/__init__.py @@ -15,5 +15,8 @@ # specific language governing permissions and limitations # under the License. """Testing utilities in meta schedule""" +from . import te_workload +from . import schedule_rule from .local_rpc import LocalRPC -from .relay_workload import get_network +from .relay_workload import MODEL_TYPE, MODEL_TYPES, get_network, get_torch_model +from .te_workload import create_te_workload diff --git a/python/tvm/meta_schedule/testing/byoc_trt.py b/python/tvm/meta_schedule/testing/byoc_trt.py new file mode 100644 index 000000000000..01d4a1fe3cae --- /dev/null +++ b/python/tvm/meta_schedule/testing/byoc_trt.py @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""TensorRT-MetaSchedule integration""" +# pylint: disable=import-outside-toplevel + +from typing import Dict, List, TYPE_CHECKING + +if TYPE_CHECKING: + from tvm.ir import IRModule + from tvm.target import Target + from tvm.runtime import NDArray, Module, Device + from tvm.meta_schedule.runner import EvaluatorConfig + + +def build_relay( + mod: "IRModule", + target: "Target", + params: Dict[str, "NDArray"], +) -> "Module": + """Build a Relay IRModule + + Parameters + ---------- + mod : IRModule + The Relay IRModule to build. + target : Target + The target to build the module for. + params : Dict[str, NDArray] + The parameter dict to build the module with. + + Returns + ------- + mod : runtime.Module + The built module. + """ + from tvm.relay.build_module import _build_module_no_factory as relay_build + from tvm.runtime import Module + + result = relay_build(mod, target=target, target_host=None, params=params) + assert isinstance(result, Module) + return result + + +def build_relay_with_tensorrt( + mod: "IRModule", + target: "Target", + params: Dict[str, "NDArray"], +) -> "Module": + """Build a Relay IRModule with TensorRT BYOC + + Parameters + ---------- + mod : IRModule + The Relay IRModule to build. + + target : Target + The target to build the module for. + + params : Dict[str, NDArray] + The parameter dict to build the module with. + + Returns + ------- + mod : runtime.Module + The built module. + """ + from tvm.ir.transform import PassContext + from tvm.relay.op.contrib import tensorrt + from tvm.relay.build_module import _build_module_no_factory as relay_build + from tvm.runtime import Module + + mod, config = tensorrt.partition_for_tensorrt(mod, params) + with PassContext( + opt_level=3, + config={"relay.ext.tensorrt.options": config}, + ): + result = relay_build(mod, target=target, target_host=None, params=params) + assert isinstance(result, Module) + return result + + +def run_with_graph_executor( + rt_mod: "Module", + device: "Device", + evaluator_config: "EvaluatorConfig", + repeated_args: List["NDArray"], +) -> List[float]: + """Run a Relay module with GraphExecutor + + Parameters + ---------- + rt_mod : Module + The Relay module to run. + device : Device + The device to run the module on. + evaluator_config : EvaluatorConfig + The evaluator configuration to run the module with. + repeated_args : List[NDArray] + The list of repeated arguments to run the module with. + + Returns + ------- + results : List[float] + The list of results. + """ + import itertools + from tvm.contrib.graph_executor import GraphModule + + graph_mod = GraphModule(rt_mod["default"](device)) + evaluator = graph_mod.module.time_evaluator( + func_name="run", + dev=device, + number=evaluator_config.number, + repeat=evaluator_config.repeat, + min_repeat_ms=evaluator_config.min_repeat_ms, + f_preproc="cache_flush_cpu_non_first_arg" + if evaluator_config.enable_cpu_cache_flush + else "", + ) + repeated_costs = [] + for args in repeated_args: + profile_result = evaluator(*args) + repeated_costs.append(profile_result.results) + costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)] + return costs diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py index 1eb9950f7fc7..bf9287a8eb18 100644 --- a/python/tvm/meta_schedule/testing/relay_workload.py +++ b/python/tvm/meta_schedule/testing/relay_workload.py @@ -15,13 +15,230 @@ # specific language governing permissions and limitations # under the License. """Workloads in Relay IR""" +from enum import Enum from typing import Dict, Tuple -import tvm.relay.testing # pylint: disable=unused-import from tvm import relay from tvm.ir import IRModule from tvm.runtime import NDArray +# Model types supported in Torchvision +class MODEL_TYPE(Enum): # pylint: disable=invalid-name + IMAGE_CLASSIFICATION = (1,) + VIDEO_CLASSIFICATION = (2,) + SEGMENTATION = (3,) + OBJECT_DETECTION = (4,) + TEXT_CLASSIFICATION = (5,) + + +# Specify the type of each model +MODEL_TYPES = { + # Image classification models + "resnet18": MODEL_TYPE.IMAGE_CLASSIFICATION, + "resnet50": MODEL_TYPE.IMAGE_CLASSIFICATION, + "alexnet": MODEL_TYPE.IMAGE_CLASSIFICATION, + "vgg16": MODEL_TYPE.IMAGE_CLASSIFICATION, + "squeezenet1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet121": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet161": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet169": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet201": MODEL_TYPE.IMAGE_CLASSIFICATION, + "inception_v3": MODEL_TYPE.IMAGE_CLASSIFICATION, + "googlenet": MODEL_TYPE.IMAGE_CLASSIFICATION, + "shufflenet_v2_x1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mobilenet_v2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mobilenet_v3_large": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mobilenet_v3_small": MODEL_TYPE.IMAGE_CLASSIFICATION, + "resnext50_32x4d": MODEL_TYPE.IMAGE_CLASSIFICATION, + "wide_resnet50_2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mnasnet1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b1": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b3": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b4": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b5": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b6": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b7": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_400mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_800mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_1_6gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_3_2gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_8gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_16gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_32gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_400mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_800mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_1_6gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_3_2gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_8gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_16gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_32gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + # Semantic Segmentation models + "fcn_resnet50": MODEL_TYPE.SEGMENTATION, + "fcn_resnet101": MODEL_TYPE.SEGMENTATION, + "deeplabv3_resnet50": MODEL_TYPE.SEGMENTATION, + "deeplabv3_resnet101": MODEL_TYPE.SEGMENTATION, + "deeplabv3_mobilenet_v3_large": MODEL_TYPE.SEGMENTATION, + "lraspp_mobilenet_v3_large": MODEL_TYPE.SEGMENTATION, + # Object detection models + # @Sung: Following networks are not runnable since Torch frontend cannot handle aten::remainder. + # "retinanet_resnet50_fpn", "keypointrcnn_resnet50_fpn", + "fasterrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "fasterrcnn_mobilenet_v3_large_fpn": MODEL_TYPE.OBJECT_DETECTION, + "fasterrcnn_mobilenet_v3_large_320_fpn": MODEL_TYPE.OBJECT_DETECTION, + "retinanet_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "maskrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "keypointrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "ssd300_vgg16": MODEL_TYPE.OBJECT_DETECTION, + "ssdlite320_mobilenet_v3_large": MODEL_TYPE.OBJECT_DETECTION, + # Video classification + "r3d_18": MODEL_TYPE.VIDEO_CLASSIFICATION, + "mc3_18": MODEL_TYPE.VIDEO_CLASSIFICATION, + "r2plus1d_18": MODEL_TYPE.VIDEO_CLASSIFICATION, + # Text classification + "bert_tiny": MODEL_TYPE.TEXT_CLASSIFICATION, + "bert_base": MODEL_TYPE.TEXT_CLASSIFICATION, + "bert_medium": MODEL_TYPE.TEXT_CLASSIFICATION, + "bert_large": MODEL_TYPE.TEXT_CLASSIFICATION, +} + + +def get_torch_model( + model_name: str, + input_shape: Tuple[int, ...], + output_shape: Tuple[int, int], # pylint: disable=unused-argument + dtype: str = "float32", +) -> Tuple[IRModule, Dict[str, NDArray]]: + """Load model from torch model zoo + Parameters + ---------- + model_name : str + The name of the model to load + input_shape: Tuple[int, ...] + Tuple for input shape + output_shape: Tuple[int, int] + Tuple for output shape + dtype: str + Tensor data type + """ + + assert dtype == "float32" + + import torch # type: ignore # pylint: disable=import-error,import-outside-toplevel + from torchvision import models # type: ignore # pylint: disable=import-error,import-outside-toplevel + import transformers # type: ignore # pylint: disable=import-error,import-outside-toplevel + import os # type: ignore # pylint: disable=import-error,import-outside-toplevel + + def do_trace(model, inp): + model.eval() + model_trace = torch.jit.trace(model, inp) + model_trace.eval() + return model_trace + + # Load model from torchvision + if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION: + model = getattr(models, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: + model = getattr(models.segmentation, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: + model = getattr(models.detection, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION: + model = getattr(models.video, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION: + os.environ["TOKENIZERS_PARALLELISM"] = "false" + config_dict = { + "bert_tiny": transformers.BertConfig( + num_hidden_layers=6, + hidden_size=512, + intermediate_size=2048, + num_attention_heads=8, + return_dict=False, + ), + "bert_base": transformers.BertConfig( + num_hidden_layers=12, + hidden_size=768, + intermediate_size=3072, + num_attention_heads=12, + return_dict=False, + ), + "bert_medium": transformers.BertConfig( + num_hidden_layers=12, + hidden_size=1024, + intermediate_size=4096, + num_attention_heads=16, + return_dict=False, + ), + "bert_large": transformers.BertConfig( + num_hidden_layers=24, + hidden_size=1024, + intermediate_size=4096, + num_attention_heads=16, + return_dict=False, + ), + } + configuration = config_dict[model_name] + model = transformers.BertModel(configuration) + A = torch.randint(10000, input_shape) + + model.eval() + scripted_model = torch.jit.trace(model, [A], strict=False) + + shape_list = [("input_ids", input_shape)] + mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) + return mod, params + else: + raise ValueError("Unsupported model in Torch model zoo.") + + # Setup input + input_data = torch.randn(input_shape).type(torch.float32) + shape_list = [("input0", input_shape)] + + # Get trace. Depending on the model type, wrapper may be necessary. + if MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: + + class TraceWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, inp): + out = self.model(inp) + return out["out"] + + wrapped_model = TraceWrapper(model) + wrapped_model.eval() + with torch.no_grad(): + scripted_model = do_trace(wrapped_model, input_data) + + elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: + + def dict_to_tuple(out_dict): + if "masks" in out_dict.keys(): + return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"] + return out_dict["boxes"], out_dict["scores"], out_dict["labels"] + + class TraceWrapper(torch.nn.Module): # type: ignore + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, inp): + out = self.model(inp) + return dict_to_tuple(out[0]) + + wrapped_model = TraceWrapper(model) + wrapped_model.eval() + with torch.no_grad(): + _ = wrapped_model(input_data) + scripted_model = do_trace(wrapped_model, input_data) + else: + scripted_model = do_trace(model, input_data) + + # Convert torch model to relay module + mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) + return mod, params + def get_network( name: str, @@ -30,6 +247,8 @@ def get_network( dtype: str = "float32", ) -> Tuple[IRModule, Dict[str, NDArray], Tuple[int, int, int, int], Tuple[int, int]]: """Get the symbol definition and random weight of a network""" + import tvm.relay.testing # pylint: disable=import-outside-toplevel,unused-import + # meta-schedule prefers NHWC layout if layout == "NHWC": image_shape = (224, 224, 3) diff --git a/python/tvm/meta_schedule/testing/run_ansor.sh b/python/tvm/meta_schedule/testing/run_ansor.sh new file mode 100644 index 000000000000..d5ea9df34485 --- /dev/null +++ b/python/tvm/meta_schedule/testing/run_ansor.sh @@ -0,0 +1,40 @@ +set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="raspi4b-aarch64" +TARGET="raspberry-pi/4b-64" +NUM_TRIALS=800 +LOG_DIR=$HOME/logs/ansor-cpu/ + +mkdir -p $LOG_DIR + +run () { + name=$1 + echo "Running workload $name" + python tests/python/meta_schedule/test_ansor_cpu.py \ + --workload "$name" \ + --target "$TARGET" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials "$NUM_TRIALS" \ + --log-dir $LOG_DIR \ + 2>&1 | tee "$LOG_DIR/$name.log" +} + +# Single op +run C1D +run C2D +run C3D +run CAP +run DEP +run DIL +run GMM +run GRP +run NRM +run T2D +# Subgraph +run C2d-BN-RELU +run TBG + diff --git a/python/tvm/meta_schedule/testing/run_meta_schedule.sh b/python/tvm/meta_schedule/testing/run_meta_schedule.sh new file mode 100644 index 000000000000..fa0c7ca42562 --- /dev/null +++ b/python/tvm/meta_schedule/testing/run_meta_schedule.sh @@ -0,0 +1,38 @@ +# set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="raspi4b-aarch64" +TARGET="raspberry-pi/4b-64" +LOG_DIR=$HOME/logs/ms-cpu/ + +mkdir -p $LOG_DIR + +run () { + name=$1 + echo "Running workload $name" + python tests/python/meta_schedule/test_tune_te_cpu.py \ + --workload "$name" \ + --target "$TARGET" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials 5000 \ + 2>&1 | tee "$LOG_DIR/$name.log" +} + +# Single op +run C1D +run C2D +# run C3D +run CAP +run DEP +run DIL +run GMM +run GRP +# run NRM +run T2D +# Subgraph +run C2d-BN-RELU +run TBG + diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py new file mode 100644 index 000000000000..83434a123a03 --- /dev/null +++ b/python/tvm/meta_schedule/testing/schedule_rule.py @@ -0,0 +1,202 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Default schedule rules""" +from typing import List + +from tvm.meta_schedule.schedule_rule import ( + AddRFactor, + AutoInline, + CrossThreadReduction, + MultiLevelTiling, + ParallelizeVectorizeUnroll, + RandomComputeLocation, + ReuseType, + ScheduleRule, +) +from tvm.target import Target + + +def get(target: Target) -> List[ScheduleRule]: + """Default schedule rules""" + if target.kind.name == "llvm": + return [ + auto_inline(target), + add_rfactor(target), + multi_level_tiling(target), + parallel_vectorize_unroll(target), + random_compute_location(target), + ] + if target.kind.name == "cuda": + return [ + multi_level_tiling(target), + auto_inline_after_tiling(target), + cross_thread_reduction(target), + parallel_vectorize_unroll(target), + ] + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def auto_inline(target: Target) -> ScheduleRule: + """Default schedule rules for auto inline""" + if target.kind.name == "llvm": + return AutoInline( + into_producer=False, + into_consumer=True, + into_cache_only=False, + inline_const_tensor=True, + disallow_if_then_else=True, + require_injective=True, + require_ordered=True, + disallow_op=["tir.exp"], + ) + if target.kind.name == "cuda": + return AutoInline( + into_producer=False, + into_consumer=True, + into_cache_only=False, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + disallow_op=None, + ) + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def auto_inline_after_tiling(target: Target) -> ScheduleRule: + """Default schedule rules for auto inline after tiling""" + if target.kind.name == "llvm": + return AutoInline( + into_producer=True, + into_consumer=True, + into_cache_only=True, + inline_const_tensor=True, + disallow_if_then_else=True, + require_injective=True, + require_ordered=True, + disallow_op=["tir.exp"], + ) + if target.kind.name == "cuda": + return AutoInline( + into_producer=True, + into_consumer=True, + into_cache_only=False, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + disallow_op=None, + ) + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def multi_level_tiling(target: Target) -> ScheduleRule: + """Default schedule rules for with multi-level tiling and reuse""" + if target.kind.name == "llvm": + return MultiLevelTiling( + structure="SSRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_lens=None, + reuse_read=None, + reuse_write=ReuseType( + req="may", + levels=[1, 2], + scope="global", + ), + ) + if target.kind.name == "cuda": + return MultiLevelTiling( + structure="SSSRRSRS", + tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], + max_innermost_factor=64, + vector_load_lens=[1, 2, 3, 4], + reuse_read=ReuseType( + req="must", + levels=[4], + scope="shared", + ), + reuse_write=ReuseType( + req="must", + levels=[3], + scope="local", + ), + ) + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def multi_level_tiling_tensor_core(target: Target) -> ScheduleRule: + """Default schedule rules for with multi-level tiling with tensor core and reuse""" + if target.kind.name == "cuda": + return MultiLevelTiling( + structure="SSSRRSRS", + tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"], + use_tensor_core=True, + max_innermost_factor=64, + vector_load_lens=[1, 2, 3, 4], + reuse_read=ReuseType( + req="must", + levels=[4], + scope="shared", + ), + reuse_write=ReuseType( + req="must", + levels=[3], + scope="local", + ), + ) + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def parallel_vectorize_unroll(target: Target) -> ScheduleRule: + """Default schedule rules for with parallel-vectorize-unroll""" + if target.kind.name == "llvm": + return ParallelizeVectorizeUnroll( + max_jobs_per_core=16, + max_vectorize_extent=32, + unroll_max_steps=[0, 16, 64, 512], + unroll_explicit=True, + ) + if target.kind.name == "cuda": + return ParallelizeVectorizeUnroll( + max_jobs_per_core=-1, + max_vectorize_extent=-1, + unroll_max_steps=[0, 16, 64, 512, 1024], + unroll_explicit=True, + ) + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def random_compute_location(target: Target) -> ScheduleRule: + """Default schedule rules for with random-compute-location""" + if target.kind.name == "llvm": + return RandomComputeLocation() + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def add_rfactor(target: Target) -> ScheduleRule: + """Default schedule rules for with add_rfactor""" + if target.kind.name == "llvm": + return AddRFactor(max_jobs_per_core=16, max_innermost_factor=64) + raise NotImplementedError(f"{target.kind.name} is not supported") + + +def cross_thread_reduction(target: Target) -> ScheduleRule: + """Default schedule rules for with cross-thread reduction""" + if target.kind.name == "cuda": + return CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]) + raise NotImplementedError(f"{target.kind.name} is not supported") diff --git a/python/tvm/meta_schedule/testing/space_generation.py b/python/tvm/meta_schedule/testing/space_generation.py new file mode 100644 index 000000000000..4abf090ddf95 --- /dev/null +++ b/python/tvm/meta_schedule/testing/space_generation.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List, Union + +from tvm.ir import IRModule +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.space_generator import PostOrderApply +from tvm.target import Target +from tvm.tir import PrimFunc, Schedule +from tvm.tir.schedule import Trace + +from . import schedule_rule as sch_rule + + +def create_context(mod: Union[IRModule, PrimFunc], target: Target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=sch_rule.get(target), + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for rule in ctx.sch_rules: + rule.initialize_with_tune_context(ctx) + return ctx + + +def check_trace(spaces: List[Schedule], expected: List[List[str]]): + expected_traces = {"\n".join(t) for t in expected} + actual_traces = set() + for space in spaces: + trace = Trace(space.trace.insts, {}) + trace = trace.simplified(remove_postproc=True) + str_trace = "\n".join(str(trace).strip().splitlines()) + actual_traces.add(str_trace) + assert str_trace in expected_traces, "\n" + str_trace + assert len(expected_traces) == len(actual_traces) + + +def debug_print_spaces(spaces: List[Schedule], trace_as_list: bool) -> None: + for i, space in enumerate(spaces): + print(f"##### Space {i}") + print(space.mod.script()) + trace = Trace(space.trace.insts, {}) + trace = trace.simplified(remove_postproc=True) + if trace_as_list: + print(str(trace).strip().splitlines()) + else: + print(trace) diff --git a/python/tvm/meta_schedule/testing/te_workload.py b/python/tvm/meta_schedule/testing/te_workload.py new file mode 100644 index 000000000000..9133b8587a95 --- /dev/null +++ b/python/tvm/meta_schedule/testing/te_workload.py @@ -0,0 +1,877 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Workloads in TE""" +# pylint: disable=missing-docstring +from typing import Tuple + +from tvm import te, tir, topi + + +def batch_matmul_nkkm( # pylint: disable=invalid-name,missing-docstring + B: int, + N: int, + M: int, + K: int, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + x = te.placeholder((B, N, K), name="X") + y = te.placeholder((B, K, M), name="Y") + k = te.reduce_axis((0, K), name="k") + z = te.compute( # pylint: disable=invalid-name + (B, N, M), + lambda b, i, j: te.sum(x[b][i][k] * y[b][k][j], axis=[k]), + name="Z", + ) + return (x, y, z) + + +def conv1d_nlc( # pylint: disable=invalid-name,missing-docstring + N: int, + L: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, L, CI), name="inputs") + weight = te.placeholder((kernel_size, CI // groups, CO), name="weight") + + batch_size, in_len, _ = inputs.shape + k_len, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + out_len = (in_len + 2 * padding - dilation * (k_len - 1) - 1) // stride + 1 + rc = te.reduce_axis((0, channel_per_group), name="rc") + rl = te.reduce_axis((0, k_len), name="rl") + + padded = topi.nn.pad(inputs, [0, padding, 0]) + output = te.compute( + (batch_size, out_len, out_channel), + lambda n, l, co: te.sum( + ( + padded[ + n, + l * stride + rl * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rl, rc, co] + ), + axis=[rl, rc], + ), + name="conv1d_nlc", + ) + return (inputs, weight, output) + + +def conv2d_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, CI), name="inputs") + weight = te.placeholder((kernel_size, kernel_size, CI // groups, CO), name="weight") + batch_size, in_h, in_w, _ = inputs.shape + k_h, k_w, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + rc = te.reduce_axis((0, channel_per_group), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, 0]) + output = te.compute( + (batch_size, out_h, out_w, out_channel), + lambda n, h, w, co: te.sum( + ( + padded[ + n, + h * stride + rh * dilation, + w * stride + rw * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rh, rw, rc, co] + ), + axis=[rh, rw, rc], + ), + name="conv2d_nhwc", + ) + return (inputs, weight, output) + + +def conv3d_ndhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + D: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, D, H, W, CI), name="inputs") + weight = te.placeholder( + (kernel_size, kernel_size, kernel_size, CI // groups, CO), name="weight" + ) + batch_size, in_d, in_h, in_w, _ = inputs.shape + k_d, k_h, k_w, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + + out_d = (in_d + 2 * padding - dilation * (k_d - 1) - 1) // stride + 1 + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rd = te.reduce_axis((0, k_d), name="rd") + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + rc = te.reduce_axis((0, channel_per_group), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, padding, 0]) + output = te.compute( + (batch_size, out_d, out_h, out_w, out_channel), + lambda n, d, h, w, co: te.sum( + ( + padded[ + n, + d * stride + rd * dilation, + h * stride + rh * dilation, + w * stride + rw * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rd, rh, rw, rc, co] + ), + axis=[rd, rh, rw, rc], + ), + name="conv3d_ndhwc", + ) + return (inputs, weight, output) + + +def depthwise_conv2d_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + C: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + factor: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, C)) + weight = te.placeholder((factor, kernel_size, kernel_size, C)) + batch_size, in_h, in_w, in_channel = inputs.shape + factor, k_h, k_w, in_channel = weight.shape + out_channel = in_channel * factor + assert int(factor) == 1, "Not optimized for factor != 1" + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + padded = topi.nn.pad(inputs, [0, padding, padding, 0]) + output = te.compute( + (batch_size, out_h, out_w, out_channel), + lambda n, h, w, c: te.sum( + ( + padded[ + n, + h * stride + rh * dilation, + w * stride + rw * dilation, + c // factor, + ] + * weight[c % factor, rh, rw, c // factor] + ), + axis=[rh, rw], + ), + name="depth_conv2d_nhwc", + ) + return (inputs, weight, output) + + +def conv2d_transpose_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, CI), name="inputs") + weight = te.placeholder((kernel_size, kernel_size, CI, CO), name="weight") + + batch, in_h, in_w, in_c = inputs.shape + filter_h, filter_w, in_c, out_c = weight.shape + stride_h, stride_w = (stride, stride) + + # compute padding + fpad_top, fpad_left, fpad_bottom, fpad_right = topi.nn.get_pad_tuple( + padding, (filter_h, filter_w) + ) + bpad_top = filter_h - 1 - fpad_top + bpad_bottom = filter_h - 1 - fpad_bottom + bpad_left = filter_w - 1 - fpad_left + bpad_right = filter_w - 1 - fpad_right + + # padding stage + padded = topi.nn.pad( + inputs, + [ + 0, + (bpad_top + stride_h - 1) // stride_h, + (bpad_left + stride_w - 1) // stride_w, + 0, + ], + [ + 0, + (bpad_bottom + stride_h - 1) // stride_h, + (bpad_right + stride_w - 1) // stride_w, + 0, + ], + ) + + # remove extra padding introduced by dilatation + idx_div = te.indexdiv + idx_mod = te.indexmod + border_h = idx_mod(stride_h - idx_mod(bpad_top, stride_h), stride_h) + border_w = idx_mod(stride_w - idx_mod(bpad_left, stride_w), stride_w) + + # dilation stage + strides = [1, stride_h, stride_w, 1] + n = len(padded.shape) + + # We should embed this dilation directly into te.compute rather than creating a new te.compute. + # Only in this way can we use unroll to eliminate the multiplication of zeros. + def _dilate(*indices): + not_zero = [] + index_tuple = [] + for i in range(n): + if not strides[i] == 1: + index_tuple.append(idx_div(indices[i], strides[i])) + not_zero.append(idx_mod(indices[i], strides[i]).equal(0)) + else: + index_tuple.append(indices[i]) + if not_zero: + not_zero = te.all(*not_zero) + return te.if_then_else(not_zero, padded(*index_tuple), tir.const(0.0, padded.dtype)) + return padded(*index_tuple) + + # convolution stage + out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + rc = te.reduce_axis((0, in_c), name="rc") + rh = te.reduce_axis((0, filter_h), name="rh") + rw = te.reduce_axis((0, filter_w), name="rw") + + output = te.compute( + (batch, out_h, out_w, out_c), + lambda n, h, w, co: te.sum( + _dilate(n, h + rh + border_h, w + rw + border_w, rc) + * weight[filter_h - 1 - rh, filter_w - 1 - rw, rc, co], + axis=[rh, rw, rc], + ), + name="conv2d_transpose_nhwc", + ) + return (inputs, weight, output) + + +def conv2d_capsule_nhwijc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + capsule_size: int = 4, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, capsule_size, capsule_size, CI), name="inputs") + weight = te.placeholder( + (kernel_size, kernel_size, capsule_size, capsule_size, CI, CO), name="weight" + ) + batch_size, in_h, in_w, _, _, in_channel = inputs.shape + k_h, k_w, _, _, _, out_channel = weight.shape + + out_h = (in_h + 2 * padding - kernel_size) // stride + 1 + out_w = (in_w + 2 * padding - kernel_size) // stride + 1 + + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + cap_k = te.reduce_axis((0, capsule_size), name="cap_k") + rc = te.reduce_axis((0, in_channel), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, 0, 0, 0]) + output = te.compute( + (batch_size, out_h, out_w, capsule_size, capsule_size, out_channel), + lambda n, h, w, cap_i, cap_j, co: te.sum( + ( + padded[n, h * stride + rh, w * stride + rw, cap_i, cap_k, rc] + * weight[rh, rw, cap_k, cap_j, rc, co] + ), + axis=[rh, rw, cap_k, rc], + ), + name="conv2d_capsule_nhwijc", + ) + return (inputs, weight, output) + + +def norm_bmn( # pylint: disable=invalid-name,missing-docstring + B: int, + M: int, + N: int, +) -> Tuple[te.Tensor, te.Tensor]: + a = te.placeholder((B, M, N), name="A") + i = te.reduce_axis((0, M), name="i") + j = te.reduce_axis((0, N), name="j") + c = te.compute( + (B,), + lambda b: te.sum(a[b][i][j] * a[b][i][j], axis=[i, j]), + name="C", + ) + d = te.compute((B,), lambda b: te.sqrt(c[b]), name="D") + return (a, d) + + +def conv2d_nhwc_without_layout_rewrite( # pylint: disable=invalid-name + Input: int, + Filter: int, + stride: int, + padding: int, + dilation: int, + out_dtype="float32", +): + """A copy of `topi.nn.conv2d_nhwc` but without the 'layout_free` attribute. + We use this in single op and subgraph evaluation + because we don't want to introduce graph level optimization. + """ + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_height, in_width, in_channel = Input.shape # type: ignore + kernel_h, kernel_w, _channel, num_filter = Filter.shape # type: ignore + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = topi.nn.get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + out_channel = num_filter + out_height = topi.utils.simplify( + (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1 + ) + out_width = topi.utils.simplify( + (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1 + ) + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + PaddedInput = topi.nn.pad(Input, pad_before, pad_after, name="PaddedInput") + rc = te.reduce_axis((0, in_channel), name="rc") + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + Output = te.compute( + (batch, out_height, out_width, out_channel), + lambda nn, yy, xx, ff: te.sum( + PaddedInput[ + nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc + ].astype(out_dtype) + * Filter[ry, rx, rc, ff].astype(out_dtype), # type: ignore + axis=[ry, rx, rc], + ), + name="Conv2dOutput", + tag="conv2d_nhwc", + ) + return Output + + +def conv2d_nhwc_bn_relu( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + strides: int, + padding: int, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor]: + data = te.placeholder((N, H, W, CI), name="data") + kernel = te.placeholder((kernel_size, kernel_size, CI, CO), name="kernel") + bias = te.placeholder((CO,), name="bias") + bn_scale = te.placeholder((CO,), name="bn_scale") + bn_offset = te.placeholder((CO,), name="bn_offset") + OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + conv = conv2d_nhwc_without_layout_rewrite(data, kernel, strides, padding, dilation) + conv = te.compute( + (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] + bias[l], name="bias_add" + ) + conv = te.compute( + (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] * bn_scale[l], name="bn_mul" + ) + conv = te.compute( + (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] + bn_offset[l], name="bn_add" + ) + out = topi.nn.relu(conv) + return (data, kernel, bias, bn_offset, bn_scale, out) + + +def transpose_batch_matmul( # pylint: disable=invalid-name,missing-docstring + batch: int, + seq_len: int, + n_head: int, + n_dim: int, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + query = te.placeholder((batch, seq_len, n_head, n_dim), name="query") + value = te.placeholder((batch, seq_len, n_head, n_dim), name="value") + query_T = te.compute( + (batch, n_head, seq_len, n_dim), + lambda b, h, l, d: query[b, l, h, d], + name="query_T", + ) + value_T = te.compute( + (batch, n_head, n_dim, seq_len), + lambda b, h, d, l: value[b, l, h, d], + name="value_T", + ) + k = te.reduce_axis((0, n_dim), name="k") + out = te.compute( + (batch, n_head, seq_len, seq_len), + lambda b, h, i, j: te.sum(query_T[b, h, i, k] * value_T[b, h, k, j], axis=[k]), + name="C", + ) + return (query, value, out) + + +def conv2d_winograd_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + tile_size = 4 # _infer_tile_size(data, kernel) + inputs = te.placeholder((N, H, W, CI), name="inputs") + N, H, W, CI = topi.utils.get_const_tuple(inputs.shape) + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation" + + KH = KW = kernel_size + HPAD, WPAD, _, _ = topi.nn.get_pad_tuple(padding, (KH, KW)) + HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride + assert HSTR == 1 and WSTR == 1 and KH == KW + + data_pad = topi.nn.pad(inputs, (0, HPAD, WPAD, 0), (0, HPAD, WPAD, 0), name="data_pad") + + r = KW + m = tile_size + alpha = m + r - 1 + A, B, _G = topi.nn.winograd_util.winograd_transform_matrices(m, r, "float32") + + H = (H + 2 * HPAD - KH) // HSTR + 1 + W = (W + 2 * WPAD - KW) // WSTR + 1 + nH, nW = (H + m - 1) // m, (W + m - 1) // m + P = N * nH * nW + _rkh = te.reduce_axis((0, KH), name="r_kh") + _rkw = te.reduce_axis((0, KW), name="r_kw") + kshape = (alpha, alpha, CI, CO) + kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight") + + idxdiv = te.indexdiv + idxmod = te.indexmod + # pack input tile + input_tile = te.compute( + (alpha, alpha, P, CI), + lambda eps, nu, p, ci: data_pad[idxdiv(p, (nH * nW))][idxmod(idxdiv(p, nW), nH) * m + eps][ + idxmod(p, nW) * m + nu + ][ci], + name="input_tile", + ) + + # transform data + r_a = te.reduce_axis((0, alpha), "r_a") + r_b = te.reduce_axis((0, alpha), "r_b") + data_pack = te.compute( + (alpha, alpha, P, CI), + lambda eps, nu, p, ci: te.sum( + input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b] + ), + name="data_pack", + attrs={"auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"]}, + ) + + # do batch gemm + ci = te.reduce_axis((0, CI), name="ci") + bgemm = te.compute( + (alpha, alpha, P, CO), + lambda eps, nu, p, co: te.sum( + data_pack[eps][nu][p][ci] * kernel_pack[eps][nu][ci][co], axis=[ci] + ), + name="bgemm", + ) + + # inverse transform + r_a = te.reduce_axis((0, alpha), "r_a") + r_b = te.reduce_axis((0, alpha), "r_b") + inverse = te.compute( + (m, m, P, CO), + lambda vh, vw, p, co: te.sum( + bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b] + ), + name="inverse", + attrs={"auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", "r_a", "r_b"]}, + ) + + # output + output = te.compute( + (N, H, W, CO), + lambda n, h, w, co: inverse[ + idxmod(h, m), idxmod(w, m), n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), co + ], + name="conv2d_winograd", + ) + + return (inputs, kernel_pack, output) + + +def matmul(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + a = te.placeholder((n, k), name="A") + b = te.placeholder((k, m), name="B") + k = te.reduce_axis((0, k), name="k") + c = te.compute( + (n, m), + lambda i, j: te.sum(a[i, k] * b[k, j], axis=[k]), + name="C", + ) + return (a, b, c) + + +def matmul_fp16(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + a = te.placeholder((n, k), name="A", dtype="float16") + b = te.placeholder((k, m), name="B", dtype="float16") + k = te.reduce_axis((0, k), name="k") + + def f_compute(i, j): + v_a = tir.Cast(dtype="float32", value=a[i, k]) + v_b = tir.Cast(dtype="float32", value=b[k, j]) + return te.sum(v_a * v_b, axis=[k]) + + c = te.compute((n, m), f_compute, name="C") + return [a, b, c] + + +def matmul_relu(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + a = te.placeholder((n, k), name="A") + b = te.placeholder((m, k), name="B") + k = te.reduce_axis((0, k), name="k") + c = te.compute( + (n, m), + lambda i, j: te.sum(a[i, k] * b[k, j], axis=[k]), + name="C", + ) + d = topi.nn.relu(c) # pylint: disable=invalid-name + return (a, b, d) + + +def matmul_relu_fp16(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + a = te.placeholder((n, k), name="A", dtype="float16") + b = te.placeholder((k, m), name="B", dtype="float16") + k = te.reduce_axis((0, k), name="k") + + def f_compute(i, j): + v_a = tir.Cast(dtype="float32", value=a[i, k]) + v_b = tir.Cast(dtype="float32", value=b[k, j]) + return te.sum(v_a * v_b, axis=[k]) + + c = te.compute((n, m), f_compute, name="C") + d = topi.nn.relu(c) # pylint: disable=invalid-name + return (a, b, d) + + +def conv2d_nchw( # pylint: disable=invalid-name + n: int, + h: int, + w: int, + ci: int, + co: int, + kh: int, + kw: int, + stride: int, + padding: int, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + x = te.placeholder((n, ci, h, w), name="X") + w = te.placeholder((co, ci, kh, kw), name="W") + y = topi.nn.conv2d_nchw(Input=x, Filter=w, stride=stride, padding=padding, dilation=dilation) + return (x, w, y) + + +def conv2d_nchw_bias_bn_relu( # pylint: disable=invalid-name + n: int, + h: int, + w: int, + ci: int, + co: int, + kh: int, + kw: int, + stride: int, + padding: int, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor]: + oh = (h + 2 * padding - (kh - 1) * dilation - 1) // stride + 1 # pylint: disable=invalid-name + ow = (w + 2 * padding - (kw - 1) * dilation - 1) // stride + 1 # pylint: disable=invalid-name + x = te.placeholder((n, ci, h, w), name="X") + w = te.placeholder((co, ci, kh, kw), name="W") + b = te.placeholder((co, 1, 1), name="B") + bn_scale = te.placeholder((co, 1, 1), name="bn_scale") + bn_offset = te.placeholder((co, 1, 1), name="bn_offset") + y = topi.nn.conv2d_nchw(Input=x, Filter=w, stride=stride, padding=padding, dilation=dilation) + y = te.compute((n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] + b[j, 0, 0], name="bias_add") + y = te.compute( + (n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] * bn_scale[j, 0, 0], name="bn_mul" + ) + y = te.compute( + (n, co, oh, ow), lambda i, j, k, l: y[i, j, k, l] + bn_offset[j, 0, 0], name="bn_add" + ) + y = topi.nn.relu(y) + return (x, w, b, bn_scale, bn_offset, y) + + +def max_pool2d_nchw( # pylint: disable=invalid-name + n: int, + h: int, + w: int, + ci: int, + padding: int, +) -> Tuple[te.Tensor, te.Tensor]: # pylint: disable=invalid-name + x = te.placeholder((n, ci, h, w), name="X") + y = topi.nn.pool2d(x, [2, 2], [1, 1], [1, 1], [padding, padding, padding, padding], "max") + return (x, y) + + +def softmax_mn(m, n) -> Tuple[te.Tensor, te.Tensor]: # pylint: disable=invalid-name + a = te.placeholder((m, n), name="A") + b = topi.nn.softmax(a, axis=1) + + return (a, b) + + +def create_te_workload(name: str, idx: int) -> tir.PrimFunc: + workload_func, params = CONFIGS[name] + return te.create_prim_func(workload_func(*params[idx])) # type: ignore + + +CONFIGS = { + "C1D": ( + conv1d_nlc, + [ + # derived from conv2d_shapes + (1, 256, 64, 128, 3, 2, 1), + # (1, 256, 64, 128, 1, 2, 0), + # (1, 256, 64, 64, 1, 1, 0), + # (1, 128, 128, 256, 3, 2, 1), + (1, 128, 128, 256, 1, 2, 0), + # (1, 128, 128, 128, 3, 1, 1), + # (1, 64, 256, 512, 3, 2, 1), + # (1, 64, 256, 512, 1, 2, 0), + (1, 64, 256, 256, 5, 1, 2), + (1, 32, 512, 512, 3, 1, 1), + ], + ), + "C2D": ( + conv2d_nhwc, + [ + # all conv2d layers in resnet-18 + (1, 224, 224, 3, 64, 7, 2, 3), + # (1, 56, 56, 64, 128, 3, 2, 1), + # (1, 56, 56, 64, 128, 1, 2, 0), + # (1, 56, 56, 64, 64, 3, 1, 1), + (1, 56, 56, 64, 64, 1, 1, 0), + # (1, 28, 28, 128, 256, 3, 2, 1), + # (1, 28, 28, 128, 256, 1, 2, 0), + # (1, 28, 28, 128, 128, 3, 1, 1), + # (1, 14, 14, 256, 512, 3, 2, 1), + # (1, 14, 14, 256, 512, 1, 2, 0), + (1, 14, 14, 256, 256, 3, 1, 1), + (1, 7, 7, 512, 512, 3, 1, 1), + ], + ), + "C3D": ( + conv3d_ndhwc, + [ + # Derived from conv2d_shapes. Use depth=16 for all configurations + (1, 16, 224, 224, 3, 64, 7, 2, 3), + # (1, 16, 56, 56, 64, 128, 3, 2, 1), + # (1, 16, 56, 56, 64, 128, 1, 2, 0), + # (1, 16, 56, 56, 64, 64, 3, 1, 1), + (1, 16, 56, 56, 64, 64, 1, 1, 0), + # (1, 16, 28, 28, 128, 256, 3, 2, 1), + # (1, 16, 28, 28, 128, 256, 1, 2, 0), + # (1, 16, 28, 28, 128, 128, 3, 1, 1), + # (1, 16, 14, 14, 256, 512, 3, 2, 1), + # (1, 16, 14, 14, 256, 512, 1, 2, 0), + (1, 16, 14, 14, 256, 256, 3, 1, 1), + (1, 16, 7, 7, 512, 512, 3, 1, 1), + ], + ), + "GMM": ( + batch_matmul_nkkm, + [ + (1, 128, 128, 128), + (1, 512, 32, 512), + (1, 512, 512, 512), + (1, 1024, 1024, 1024), + ], + ), + "GRP": ( + conv2d_nhwc, + [ + # Derived from conv2d_shapes. Use group=4 for all configurations + (1, 56, 56, 64, 128, 3, 2, 1, 1, 4), + # (1, 56, 56, 64, 128, 1, 2, 0 , 1, 4), + # (1, 56, 56, 64, 64, 3, 1, 1 , 1, 4), + (1, 56, 56, 64, 64, 1, 1, 0, 1, 4), + # (1, 28, 28, 128, 256, 3, 2, 1, 1, 4), + # (1, 28, 28, 128, 256, 1, 2, 0, 1, 4), + # (1, 28, 28, 128, 128, 3, 1, 1, 1, 4), + # (1, 14, 14, 256, 512, 3, 2, 1, 1, 4), + # (1, 14, 14, 256, 512, 1, 2, 0, 1, 4), + (1, 14, 14, 256, 256, 3, 1, 1, 1, 4), + (1, 7, 7, 512, 512, 3, 1, 1, 1, 4), + ], + ), + "DIL": ( + conv2d_nhwc, + [ + # Derived from conv2d_shapes. Use dilation=2 for all configurations + (1, 224, 224, 3, 64, 7, 2, 3, 2), + # (1, 56, 56, 64, 128, 3, 2, 1 , 2), + # (1, 56, 56, 64, 128, 1, 2, 0 , 2), + # (1, 56, 56, 64, 64, 3, 1, 1 , 2), + (1, 56, 56, 64, 64, 1, 1, 0, 2), + # (1, 28, 28, 128, 256, 3, 2, 1, 2), + # (1, 28, 28, 128, 256, 1, 2, 0, 2), + # (1, 28, 28, 128, 128, 3, 1, 1, 2), + # (1, 14, 14, 256, 512, 3, 2, 1, 2), + # (1, 14, 14, 256, 512, 1, 2, 0, 2), + (1, 14, 14, 256, 256, 3, 1, 1, 2), + (1, 7, 7, 512, 512, 3, 1, 1, 2), + ], + ), + "DEP": ( + depthwise_conv2d_nhwc, + [ + # all depthwise conv2d layers in mobilenet + (1, 112, 112, 32, 3, 1, 1), + (1, 112, 112, 64, 3, 2, 1), + # (1, 56, 56, 128, 3, 1, 1), + # (1, 56, 56, 128, 3, 2, 1), + # (1, 28, 28, 256, 3, 1, 1), + # (1, 28, 28, 256, 3, 2, 1), + # (1, 14, 14, 512, 3, 1, 1), + (1, 14, 14, 512, 3, 2, 1), + (1, 7, 7, 1024, 3, 1, 1), + ], + ), + "T2D": ( + conv2d_transpose_nhwc, + [ + # all conv2d tranpose layers in DCGAN + (1, 4, 4, 512, 256, 4, 2, 1), + (1, 8, 8, 256, 128, 4, 2, 1), + (1, 16, 16, 128, 64, 4, 2, 1), + (1, 32, 32, 64, 3, 4, 2, 1), + ], + ), + "CAP": ( + conv2d_capsule_nhwijc, + [ + # all conv2d capsule layers in matrix capsules withemrouting (ICLR 2018) + (1, 16, 16, 32, 32, 3, 2, 1), + (1, 8, 8, 32, 32, 3, 1, 1), + (1, 16, 16, 8, 16, 3, 2, 1), + (1, 8, 8, 16, 16, 3, 1, 1), + ], + ), + "NRM": ( + norm_bmn, + [ + (1, 256, 256), + (1, 512, 512), + (1, 1024, 1024), + (1, 4096, 1024), + ], + ), + "SFM": ( + softmax_mn, + [ + (256, 256), + (512, 512), + (1024, 1024), + (2048, 2048), + ], + ), + "C2d-BN-RELU": ( + conv2d_nhwc_bn_relu, + [ + (1, 224, 224, 3, 64, 7, 2, 3), + (1, 56, 56, 64, 128, 3, 2, 1), + (1, 28, 28, 128, 256, 1, 2, 0), + (1, 7, 7, 512, 512, 3, 1, 1), + ], + ), + "TBG": ( + transpose_batch_matmul, + [ + (1, 128, 12, 64), + (1, 128, 16, 64), + (1, 64, 12, 128), + (1, 128, 12, 128), + ], + ), +} diff --git a/python/tvm/meta_schedule/testing/test_ansor_cpu.py b/python/tvm/meta_schedule/testing/test_ansor_cpu.py new file mode 100644 index 000000000000..36e42c2ab636 --- /dev/null +++ b/python/tvm/meta_schedule/testing/test_ansor_cpu.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import argparse +import os + +import tvm +from tvm import auto_scheduler +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.te_workload import CONFIGS + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--workload", + type=str, + required=True, + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + required=True, + ) + args.add_argument( + "--rpc-port", + type=int, + required=True, + ) + args.add_argument( + "--rpc-key", + type=str, + required=True, + ) + args.add_argument( + "--log-dir", + type=str, + required=True, + ) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=60, + ) + parsed.rpc_workers = rpc_config.count_num_servers(allow_missing=False) + return parsed + + +ARGS = _parse_args() + + +def main(): + log_file = os.path.join(ARGS.log_dir, f"{ARGS.workload}.json") + workload_func, params = CONFIGS[ARGS.workload] + params = params[0] + workload_func = auto_scheduler.register_workload(workload_func) + task = auto_scheduler.SearchTask( + func=workload_func, + args=params, + target=ARGS.target, + hardware_params=auto_scheduler.HardwareParams( + num_cores=int(ARGS.target.attrs["num-cores"]), + target=ARGS.target, + ), + ) + runner = auto_scheduler.RPCRunner( + key=ARGS.rpc_key, + host=ARGS.rpc_host, + port=ARGS.rpc_port, + n_parallel=ARGS.rpc_workers, + ) + + # Inspect the computational graph + print("Computational DAG:") + print(task.compute_dag) + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=ARGS.num_trials, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + verbose=2, + runner=runner, + ) + print("Running AutoTuning:") + task.tune(tune_option) + print("History Best:") + print(task.print_best(log_file)) + sch, args = task.apply_best(log_file) + print("Lowered TIR:") + print(tvm.lower(sch, args, simple_mode=True)) + + +if __name__ == "__main__": + main() diff --git a/python/tvm/meta_schedule/testing/test_tune_te_cpu.py b/python/tvm/meta_schedule/testing/test_tune_te_cpu.py new file mode 100644 index 000000000000..b48fc4f9a04c --- /dev/null +++ b/python/tvm/meta_schedule/testing/test_tune_te_cpu.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import argparse +import logging + +import tvm +from tvm import meta_schedule as ms +from tvm import tir +from tvm.meta_schedule.testing import create_te_workload + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--workload", + type=str, + required=True, + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + required=True, + ) + args.add_argument( + "--rpc-port", + type=int, + required=True, + ) + args.add_argument( + "--rpc-key", + type=str, + required=True, + ) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + parsed.rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=60, + ) + parsed.rpc_workers = parsed.rpc_config.count_num_servers(allow_missing=False) + return parsed + + +logging.basicConfig() +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) +ARGS = _parse_args() + + +def main(): + runner = ms.runner.RPCRunner( + rpc_config=ARGS.rpc_config, + alloc_repeat=3, + max_workers=ARGS.rpc_workers, + ) + sch: tir.Schedule = ms.tune_tir( + mod=create_te_workload(ARGS.workload, 0), + target=ARGS.target, + config=ms.ReplayTraceConfig( + num_trials_per_iter=64, + num_trials_total=ARGS.num_trials, + ), + runner=runner, + task_name=ARGS.workload, + ) + if sch is None: + print("No valid schedule found!") + else: + print(sch.mod.script()) + print(sch.trace) + + +if __name__ == "__main__": + main() diff --git a/python/tvm/meta_schedule/testing/tir_tensor_intrin.py b/python/tvm/meta_schedule/testing/tir_tensor_intrin.py new file mode 100644 index 000000000000..76f1920c2777 --- /dev/null +++ b/python/tvm/meta_schedule/testing/tir_tensor_intrin.py @@ -0,0 +1,307 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A collection of TIR tensor intrinsics""" +# pylint: disable=missing-function-docstring +import tvm +from tvm import tir +from tvm.script import tir as T + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks +# fmt: off + +@T.prim_func +def tensorcore_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + vkk = T.axis.R(16, vk + k) + C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] + + +@T.prim_func +def tensorcore_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + T.reads([ + C[vi : vi + 16, vj : vj + 16], + A[vi : vi + 16, vk : vk + 16], + B[vj : vj + 16, vk : vk + 16], + ]) + T.writes(C[vi : vi + 16, vj : vj + 16]) + T.evaluate( + T.tvm_mma_sync( + C.data, + C.elem_offset // 256, + A.data, + A.elem_offset // 256, + B.data, + B.elem_offset // 256, + C.data, + C.elem_offset // 256, + dtype="handle", + ) + ) + + +@T.prim_func +def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,)) + B = T.match_buffer(b, (4,)) + C = T.match_buffer(c, (1,)) + + with T.block("root"): + v0 = T.axis.R(4, 0) + for i in range(0, 4): + with T.block("update"): + vi = T.axis.R(4, v0 + i) + C[0] = C[0] + A[vi] * B[vi] + + +@T.prim_func +def dot_product_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,)) + B = T.match_buffer(b, (4,)) + C = T.match_buffer(c, (1,)) + + with T.block("root"): + v0 = T.axis.R(4, 0) + T.reads([C[0 : 1], A[v0 : v0 + 4], B[v0 : v0 + 4]]) + T.writes([C[0 : 1]]) + T.evaluate(T.call_extern( # pylint: disable=redundant-keyword-arg + "vec4add", + C.data, C.elem_offset, + A.data, A.elem_offset, + B.data, B.elem_offset, + dtype="int32", + )) + +@T.prim_func +def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_a") + B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=1, scope="wmma.matrix_b") + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=1, scope="wmma.accumulator") + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + vkk = T.axis.R(16, vk + k) + C[vii, vjj] = C[vii, vjj] + T.cast(A[vii, vkk], "float32") * T.cast(B[vkk, vjj], + "float32") + + +@T.prim_func +def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a") + B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b") + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, + scope="wmma.accumulator") + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + T.reads([C[vi: vi+16, vj: vj+16], A[vi: vi+16, vk: vk+16], B[vk: vk+16, vj: vj+16]]) + T.writes(C[vi: vi+16, vj: vj+16]) + T.evaluate(T.tvm_mma_sync(C.data, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), + A.data, A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16), + B.data, B.elem_offset // 256 + T.floordiv(T.floormod(B.elem_offset, 256), 16), + C.data, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), + dtype="handle")) + + +@T.prim_func +def wmma_load_a_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, + scope="shared") + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, + scope="wmma.matrix_a") + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + for i, j in T.grid(16, 16): + with T.block("load"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + C[vii, vjj] = A[vii, vjj] + + +@T.prim_func +def wmma_load_a_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared", strides=[s1, s0]) + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a") + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + T.reads(A[vi: vi+16, vj: vj+16]) + T.writes(C[vi: vi+16, vj: vj+16]) + T.evaluate(T.tvm_load_matrix_sync( + C.data, 16, 16, 16, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), A.access_ptr("r"), s1, "row_major", + dtype="handle")) + + +@T.prim_func +def wmma_load_b_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared") + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b") + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + for i, j in T.grid(16, 16): + with T.block("load"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + C[vii, vjj] = A[vii, vjj] + + +@T.prim_func +def wmma_load_b_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared", strides=[s1, s0]) + C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b") + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + T.reads(A[vi: vi+16, vj: vj+16]) + T.writes(C[vi: vi+16, vj: vj+16]) + T.evaluate(T.tvm_load_matrix_sync( + C.data, 16, 16, 16, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), A.access_ptr("r"), s1, "row_major", + dtype="handle")) + + +@T.prim_func +def wmma_fill_desc(c: T.handle) -> None: + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator") + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + for i, j in T.grid(16, 16): + with T.block("init"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + C[vii, vjj] = T.float32(0) + + +@T.prim_func +def wmma_fill_impl(c: T.handle) -> None: + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator") + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + T.reads([]) + T.writes(C[vi : vi + 16, vj : vj + 16]) + T.evaluate(T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16), T.float32(0), dtype="handle")) + + +@T.prim_func +def wmma_store_desc(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator") + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="global") + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + for i, j in T.grid(16, 16): + with T.block("store"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + C[vii, vjj] = A[vii, vjj] + + +@T.prim_func +def wmma_store_impl(a: T.handle, c: T.handle) -> None: + s1 = T.var("int32") + s0 = T.var("int32") + A = T.match_buffer(a, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator") + C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="global", strides=[s1, s0]) + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + T.reads(A[vi: vi + 16, vj: vj + 16]) + T.writes(C[vi: vi+16, vj: vj+16]) + T.evaluate(T.tvm_store_matrix_sync( + A.data, 16, 16, 16, A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16), C.access_ptr("w"), s1, "row_major", + dtype="handle")) + + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks + +TENSORCORE_WMMA = tir.TensorIntrin.register( + "test.tensorcore.wmma", + tensorcore_desc, + tensorcore_impl, +) + +NEON_DOT = tir.TensorIntrin.register( + "test.neon.dot", + dot_product_desc, + dot_product_impl, +) + +WMMA_SYNC = tir.TensorIntrin.register( + "wmma_sync", + wmma_sync_desc, + wmma_sync_impl, +) + +WMMA_LOAD_A = tir.TensorIntrin.register( + "wmma_load_a", + wmma_load_a_desc, + wmma_load_a_impl, +) + +WMMA_LOAD_B = tir.TensorIntrin.register( + "wmma_load_b", + wmma_load_b_desc, + wmma_load_b_impl, +) + +WMMA_FILL = tir.TensorIntrin.register( + "wmma_fill", + wmma_fill_desc, + wmma_fill_impl, +) + +WMMA_FILL = tir.TensorIntrin.register( + "wmma_store", + wmma_store_desc, + wmma_store_impl, +) diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py new file mode 100644 index 000000000000..4f38d7cc98be --- /dev/null +++ b/python/tvm/meta_schedule/tune.py @@ -0,0 +1,721 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""User-facing Tuning API""" + +import logging +import os.path +from typing import Callable, Dict, List, Optional, Union + +import tvm +from tvm import relay +from tvm._ffi.registry import register_func +from tvm.relay import Function as RelayFunc +from tvm.relay.backend.executor_factory import ExecutorFactoryModule +from tvm.ir.base import structural_equal, structural_hash +from tvm.ir.module import IRModule +from tvm.runtime import NDArray +from tvm.target.target import Target +from tvm.te import Tensor, create_prim_func +from tvm.tir import PrimFunc, Schedule + +from .integration import extract_task_from_relay, ApplyHistoryBest +from .builder import Builder, LocalBuilder +from .cost_model import CostModel, XGBModel +from .database import Database, JSONDatabase, TuningRecord +from .feature_extractor import PerStoreFeature +from .measure_callback import MeasureCallback +from .mutator import Mutator +from .postproc import Postproc +from .runner import LocalRunner, Runner +from .schedule_rule import ScheduleRule +from .search_strategy import ( + EvolutionarySearchConfig, + ReplayFuncConfig, + ReplayTraceConfig, +) +from .space_generator import PostOrderApply, SpaceGenerator +from .task_scheduler import RoundRobin, TaskScheduler +from .tune_context import TuneContext + + +logger = logging.getLogger(__name__) + +SearchStrategyConfig = Union[ + ReplayFuncConfig, + ReplayTraceConfig, + EvolutionarySearchConfig, +] +TypeSpaceGenerator = Callable[[], SpaceGenerator] +TypeScheduleRule = Callable[[], List[ScheduleRule]] +TypePostproc = Callable[[], List[Postproc]] +TypeMutatorProb = Callable[[], Dict[Mutator, float]] +TypeTaskScheduler = Callable[ + [ + List[TuneContext], + Builder, + Runner, + Database, + CostModel, + List[MeasureCallback], + ], + TaskScheduler, +] + + +class DefaultLLVM: + """Default tuning configuration for LLVM.""" + + @staticmethod + def _sch_rules() -> List[ScheduleRule]: + from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel + schedule_rule as M, + ) + + return [ + M.AutoInline( + into_producer=False, + into_consumer=True, + into_cache_only=False, + inline_const_tensor=True, + disallow_if_then_else=True, + require_injective=True, + require_ordered=True, + disallow_op=["tir.exp"], + ), + M.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64), + M.MultiLevelTiling( + structure="SSRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_lens=None, + reuse_read=None, + reuse_write=M.ReuseType( + req="may", + levels=[1, 2], + scope="global", + ), + ), + M.ParallelizeVectorizeUnroll( + max_jobs_per_core=16, + max_vectorize_extent=64, + unroll_max_steps=[0, 16, 64, 512], + unroll_explicit=True, + ), + M.RandomComputeLocation(), + ] + + @staticmethod + def _postproc() -> List[Postproc]: + from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel + postproc as M, + ) + + return [ + M.DisallowDynamicLoop(), + M.RewriteParallelVectorizeUnroll(), + M.RewriteReductionBlock(), + ] + + @staticmethod + def _mutator_probs() -> Dict[Mutator, float]: + from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel + mutator as M, + ) + + return { + M.MutateTileSize(): 0.9, + M.MutateComputeLocation(): 0.05, + M.MutateUnroll(): 0.03, + M.MutateParallel(max_jobs_per_core=16): 0.02, + } + + +class DefaultCUDA: + """Default tuning configuration for CUDA.""" + + @staticmethod + def _sch_rules() -> List[ScheduleRule]: + from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel + schedule_rule as M, + ) + + return [ + M.MultiLevelTiling( + structure="SSSRRSRS", + tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], + max_innermost_factor=64, + vector_load_lens=[1, 2, 3, 4], + reuse_read=M.ReuseType( + req="must", + levels=[4], + scope="shared", + ), + reuse_write=M.ReuseType( + req="must", + levels=[3], + scope="local", + ), + ), + M.AutoInline( + into_producer=True, + into_consumer=True, + into_cache_only=False, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + disallow_op=None, + ), + M.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]), + M.ParallelizeVectorizeUnroll( + max_jobs_per_core=-1, # disable parallelize + max_vectorize_extent=-1, # disable vectorize + unroll_max_steps=[0, 16, 64, 512, 1024], + unroll_explicit=True, + ), + ] + + @staticmethod + def _postproc() -> List[Postproc]: + from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel + postproc as M, + ) + + return [ + M.DisallowDynamicLoop(), + M.RewriteCooperativeFetch(), + M.RewriteUnboundBlock(), + M.RewriteParallelVectorizeUnroll(), + M.RewriteReductionBlock(), + M.VerifyGPUCode(), + ] + + @staticmethod + def _mutator_probs() -> Dict[Mutator, float]: + from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel + mutator as M, + ) + + return { + M.MutateTileSize(): 0.9, + M.MutateUnroll(): 0.1, + } + + +class Parse: + """Parse tuning configuration from user inputs.""" + + @staticmethod + @register_func("tvm.meta_schedule.tune.parse_mod") # for use in ApplyHistoryBest + def _mod(mod: Union[PrimFunc, IRModule]) -> IRModule: + if isinstance(mod, PrimFunc): + mod = mod.with_attr("global_symbol", "main") + mod = mod.with_attr("tir.noalias", True) + mod = IRModule({"main": mod}) + if not isinstance(mod, IRModule): + raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}") + # in order to make sure the mod can be found in ApplyHistoryBest + # different func name can cause structural unequal + if "main" not in mod.global_var_map_: + (func_name,) = [global_var for global_var in mod.global_var_map_] + mod = IRModule({"main": mod[func_name]}) + return mod + + @staticmethod + def _target(target: Union[str, Target]) -> Target: + if isinstance(target, str): + target = Target(target) + if not isinstance(target, Target): + raise TypeError(f"Expected `target` to be str or Target, but gets: {target}") + return target + + @staticmethod + def _builder(builder: Optional[Builder]) -> Builder: + if builder is None: + builder = LocalBuilder() + if not isinstance(builder, Builder): + raise TypeError(f"Expected `builder` to be Builder, but gets: {builder}") + return builder + + @staticmethod + def _runner(runner: Optional[Runner]) -> Runner: + if runner is None: + runner = LocalRunner() + if not isinstance(runner, Runner): + raise TypeError(f"Expected `runner` to be Runner, but gets: {runner}") + return runner + + @staticmethod + def _database(database: Union[None, Database], task_name: str, path: str) -> Database: + if database is None: + path_workload = os.path.join(path, f"{task_name}_database_workload.json") + path_tuning_record = os.path.join(path, f"{task_name}_database_tuning_record.json") + logger.info( + "Creating JSONDatabase. Workload at: %s. Tuning records at: %s", + path_workload, + path_tuning_record, + ) + database = JSONDatabase( + path_workload=path_workload, + path_tuning_record=path_tuning_record, + ) + if not isinstance(database, Database): + raise TypeError(f"Expected `database` to be Database, but gets: {database}") + return database + + @staticmethod + def _callbacks( + measure_callbacks: Optional[List[MeasureCallback]], + ) -> List[MeasureCallback]: + if measure_callbacks is None: + from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel + measure_callback as M, + ) + + return [ + M.AddToDatabase(), + M.RemoveBuildArtifact(), + M.EchoStatistics(), + M.UpdateCostModel(), + ] + if not isinstance(measure_callbacks, (list, tuple)): + raise TypeError( + f"Expected `measure_callbacks` to be List[MeasureCallback], " + f"but gets: {measure_callbacks}" + ) + measure_callbacks = list(measure_callbacks) + for i, callback in enumerate(measure_callbacks): + if not isinstance(callback, MeasureCallback): + raise TypeError( + f"Expected `measure_callbacks` to be List[MeasureCallback], " + f"but measure_callbacks[{i}] is: {callback}" + ) + return measure_callbacks + + @staticmethod + def _cost_model(cost_model: Optional[CostModel]) -> CostModel: + if cost_model is None: + return XGBModel(extractor=PerStoreFeature()) + if not isinstance(cost_model, CostModel): + raise TypeError(f"Expected `cost_model` to be CostModel, but gets: {cost_model}") + return cost_model + + @staticmethod + def _space_generator(space_generator: Optional[TypeSpaceGenerator]) -> SpaceGenerator: + if space_generator is None: + return PostOrderApply() + if callable(space_generator): + space_generator = space_generator() + if not isinstance(space_generator, SpaceGenerator): + raise TypeError( + f"Expected `space_generator` to return SpaceGenerator, " + f"but gets: {space_generator}" + ) + return space_generator + + @staticmethod + def _sch_rules(sch_rules: Optional[TypeScheduleRule], target: Target) -> List[ScheduleRule]: + if callable(sch_rules): + return sch_rules() + if sch_rules is not None: + raise TypeError(f"Expected `sch_rules` to be None or callable, but gets: {sch_rules}") + # pylint: disable=protected-access + if target.kind.name == "llvm": + return DefaultLLVM._sch_rules() + if target.kind.name == "cuda": + return DefaultCUDA._sch_rules() + # pylint: enable=protected-access + raise ValueError(f"Unsupported target: {target}") + + @staticmethod + def _postproc(postproc: Optional[TypePostproc], target: Target) -> List[Postproc]: + if callable(postproc): + return postproc() + if postproc is not None: + raise TypeError(f"Expected `postproc` to be None or callable, but gets: {postproc}") + # pylint: disable=protected-access + if target.kind.name == "llvm": + return DefaultLLVM._postproc() + if target.kind.name == "cuda": + return DefaultCUDA._postproc() + # pylint: enable=protected-access + raise ValueError(f"Unsupported target: {target}") + + @staticmethod + def _mutator_probs( + mutator_probs: Optional[TypeMutatorProb], + target: Target, + ) -> Dict[Mutator, float]: + if callable(mutator_probs): + return mutator_probs() + if mutator_probs is not None: + raise TypeError( + f"Expected `mutator_probs` to be None or callable, but gets: {mutator_probs}" + ) + # pylint: disable=protected-access + if target.kind.name == "llvm": + return DefaultLLVM._mutator_probs() + if target.kind.name == "cuda": + return DefaultCUDA._mutator_probs() + # pylint: enable=protected-access + raise ValueError(f"Unsupported target: {target}") + + @staticmethod + def _tune_context( + tune_context: Optional[TuneContext], + mod: IRModule, + target: Target, + config: SearchStrategyConfig, + task_name: str, + space_generator: Optional[TypeSpaceGenerator], + sch_rules: Optional[TypeScheduleRule], + postprocs: Optional[TypePostproc], + mutator_probs: Optional[TypeMutatorProb], + num_threads: Optional[int], + ) -> TuneContext: + if tune_context is None: + return TuneContext( + mod=mod, + target=target, + # pylint: disable=protected-access + space_generator=Parse._space_generator(space_generator), + search_strategy=config.create_strategy(), + sch_rules=Parse._sch_rules(sch_rules, target), + postprocs=Parse._postproc(postprocs, target), + mutator_probs=Parse._mutator_probs(mutator_probs, target), + # pylint: enable=protected-access + task_name=task_name, + rand_state=-1, + num_threads=num_threads, + ) + if not isinstance(tune_context, TuneContext): + raise TypeError(f"Expected `tune_context` to be TuneContext, but gets: {tune_context}") + return tune_context + + @staticmethod + def _task_scheduler( + task_scheduler: Union[None, TaskScheduler, TypeTaskScheduler], + tasks: List[TuneContext], + builder: Builder, + runner: Runner, + database: Database, + cost_model: CostModel, + measure_callbacks: List[MeasureCallback], + ): + if task_scheduler is None: + return RoundRobin( + tasks=tasks, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + ) + if callable(task_scheduler): + return task_scheduler( + tasks, + builder, + runner, + database, + cost_model, + measure_callbacks, + ) + if not isinstance(task_scheduler, TaskScheduler): + raise TypeError( + f"Expected `task_scheduler` to be TaskScheduler, but gets: {task_scheduler}" + ) + return task_scheduler + + +def tune_tir( + mod: Union[IRModule, PrimFunc], + target: Union[str, Target], + config: SearchStrategyConfig, + work_dir: str, + *, + task_name: str = "main", + builder: Optional[Builder] = None, + runner: Optional[Runner] = None, + database: Optional[Database] = None, + cost_model: Optional[CostModel] = None, + measure_callbacks: Optional[List[MeasureCallback]] = None, + task_scheduler: Optional[TaskScheduler] = None, + space: Optional[TypeSpaceGenerator] = None, + sch_rules: Optional[TypeScheduleRule] = None, + postprocs: Optional[TypePostproc] = None, + mutator_probs: Optional[TypeMutatorProb] = None, + num_threads: Optional[int] = None, +) -> Optional[Schedule]: + """Tune a TIR IRModule with a given target. + + Parameters + ---------- + mod : Union[IRModule, PrimFunc] + The module to tune. + target : Union[str, Target] + The target to tune for. + config : SearchStrategyConfig + The search strategy config. + task_name : str + The name of the task. + work_dir : Optional[str] + The working directory to save intermediate results. + builder : Optional[Builder] + The builder to use. + runner : Optional[Runner] + The runner to use. + database : Optional[Database] + The database to use. + cost_model : Optional[CostModel] + The cost model to use. + measure_callbacks : Optional[List[MeasureCallback]] + The callbacks used during tuning. + f_tune_context : Optional[TYPE_F_TUNE_CONTEXT] + The function to create TuneContext. + f_task_scheduler : Optional[TYPE_F_TASK_SCHEDULER] + The function to create TaskScheduler. + + Returns + ------- + sch : Optional[Schedule] + The tuned schedule. + """ + + logger.info("Working directory: %s", work_dir) + # pylint: disable=protected-access + mod = Parse._mod(mod) + database = Parse._database(database, task_name, work_dir) + tune_context = Parse._tune_context( + tune_context=None, + mod=mod, + target=Parse._target(target), + config=config, + task_name=task_name, + space_generator=space, + sch_rules=sch_rules, + postprocs=postprocs, + mutator_probs=mutator_probs, + num_threads=num_threads, + ) + task_scheduler = Parse._task_scheduler( + task_scheduler, + [tune_context], + builder=Parse._builder(builder), + runner=Parse._runner(runner), + database=database, + cost_model=Parse._cost_model(cost_model), + measure_callbacks=Parse._callbacks(measure_callbacks), + ) + # pylint: enable=protected-access + task_scheduler.tune() + bests: List[TuningRecord] = database.get_top_k( + database.commit_workload(mod), + top_k=1, + ) + if not bests: + return None + assert len(bests) == 1 + sch = Schedule(mod) + bests[0].trace.apply_to_schedule(sch, remove_postproc=False) + task_scheduler.cost_model.save(os.path.join(work_dir, f"{task_name}.xgb")) + return sch + + +def tune_te( + tensors: List[Tensor], + target: Union[str, Target], + config: SearchStrategyConfig, + work_dir: str, + *, + task_name: str = "main", + builder: Optional[Builder] = None, + runner: Optional[Runner] = None, + database: Optional[Database] = None, + cost_model: Optional[CostModel] = None, + measure_callbacks: Optional[List[MeasureCallback]] = None, + task_scheduler: Optional[TaskScheduler] = None, + space: Optional[TypeSpaceGenerator] = None, + sch_rules: Optional[TypeScheduleRule] = None, + postprocs: Optional[TypePostproc] = None, + mutator_probs: Optional[TypeMutatorProb] = None, + num_threads: Optional[int] = None, +) -> Optional[Schedule]: + """Tune a TE compute DAG with a given target. + + Parameters + ---------- + tensor : List[Tensor] + The list of input/output tensors of the TE compute DAG. + target : Union[str, Target] + The target to tune for. + config : SearchStrategyConfig + The search strategy config. + task_name : str + The name of the task. + work_dir : Optional[str] + The working directory to save intermediate results. + builder : Optional[Builder] + The builder to use. + runner : Optional[Runner] + The runner to use. + database : Optional[Database] + The database to use. + measure_callbacks : Optional[List[MeasureCallback]] + The callbacks used during tuning. + f_tune_context : Optional[TYPE_F_TUNE_CONTEXT] + The function to create TuneContext. + f_task_scheduler : Optional[TYPE_F_TASK_SCHEDULER] + The function to create TaskScheduler. + + Returns + ------- + sch : Optional[Schedule] + The tuned schedule. + """ + return tune_tir( + mod=create_prim_func(tensors), + target=target, + config=config, + work_dir=work_dir, + task_name=task_name, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, + space=space, + sch_rules=sch_rules, + postprocs=postprocs, + mutator_probs=mutator_probs, + num_threads=num_threads, + ) + + +def tune_relay( + mod: Union[RelayFunc, IRModule], + target: Union[str, Target], + config: SearchStrategyConfig, + work_dir: str, + *, + params: Optional[Dict[str, NDArray]] = None, + task_name: str = "main", + builder: Optional[Builder] = None, + runner: Optional[Runner] = None, + database: Optional[Database] = None, + cost_model: Optional[CostModel] = None, + measure_callbacks: Optional[List[MeasureCallback]] = None, + task_scheduler: Optional[TaskScheduler] = None, + space: Optional[TypeSpaceGenerator] = None, + sch_rules: Optional[TypeScheduleRule] = None, + postprocs: Optional[TypePostproc] = None, + mutator_probs: Optional[TypeMutatorProb] = None, + num_threads: Optional[int] = None, +) -> ExecutorFactoryModule: + """Tune a TIR IRModule with a given target. + + Parameters + ---------- + mod : Union[RelayFunc, IRModule] + The module to tune. + target : Union[str, Target] + The target to tune for. + config : SearchStrategyConfig + The search strategy config. + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + task_name : str + The name of the task. + work_dir : Optional[str] + The working directory to save intermediate results. + builder : Optional[Builder] + The builder to use. + runner : Optional[Runner] + The runner to use. + database : Optional[Database] + The database to use. + measure_callbacks : Optional[List[MeasureCallback]] + The callbacks used during tuning. + f_tune_context : Optional[TYPE_F_TUNE_CONTEXT] + The function to create TuneContext. + f_task_scheduler : Optional[TYPE_F_TASK_SCHEDULER] + The function to create TaskScheduler. + + Returns + ------- + lib : ExecutorFactoryModule + The built runtime module for the given relay workload. + """ + + logger.info("Working directory: %s", work_dir) + extracted_tasks = extract_task_from_relay(mod, target, params) + # pylint: disable=protected-access + tune_contexts = [] + target = Parse._target(target) + database = Parse._database(database, task_name, work_dir) + # parse the tuning contexts + for task in extracted_tasks: + assert len(task.dispatched) == 1, "Only size 1 dispatched task list is supported for now" + tune_contexts.append( + Parse._tune_context( + tune_context=None, + mod=Parse._mod(task.dispatched[0]), + target=target, + config=config, + task_name=task.task_name, + space_generator=space, + sch_rules=sch_rules, + postprocs=postprocs, + mutator_probs=mutator_probs, + num_threads=num_threads, + ) + ) + # deduplication + logger.info(f"Before task deduplication: {len(tune_contexts)} tasks") + tasks: List[TuneContext] = [] + hashs: List[int] = [] + for i, task in enumerate(tune_contexts): + struct_hash: int = structural_hash(task.mod) + flag: bool = False + if struct_hash in hashs: + for other_task in tune_contexts[i + 1 :]: + if structural_equal(task.mod, other_task.mod): + flag = True + break + if not flag: + tasks.append(task) + hashs.append(struct_hash) + logger.info(f"After task deduplication: {len(tasks)} tasks") + + # parse the task scheduler + task_scheduler = Parse._task_scheduler( + task_scheduler, + tasks, + builder=Parse._builder(builder), + runner=Parse._runner(runner), + database=database, + cost_model=Parse._cost_model(cost_model), + measure_callbacks=Parse._callbacks(measure_callbacks), + ) + # pylint: enable=protected-access + task_scheduler.tune() + with ApplyHistoryBest(database): + with tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_meta_schedule": True}, + ): + return relay.build(mod, target=target, params=params) diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 99b8c7e869cd..196b1c16b6f2 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -16,13 +16,14 @@ # under the License. """Meta Schedule tuning context.""" -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, Dict, TYPE_CHECKING from tvm import IRModule from tvm._ffi import register_object from tvm.meta_schedule.utils import cpu_count from tvm.runtime import Object from tvm.target import Target +from tvm.tir import PrimFunc from . import _ffi_api @@ -30,6 +31,8 @@ from .space_generator import SpaceGenerator from .search_strategy import SearchStrategy from .schedule_rule import ScheduleRule + from .postproc import Postproc + from .mutator import Mutator @register_object("meta_schedule.TuneContext") @@ -53,6 +56,10 @@ class TuneContext(Object): The search strategy. sch_rules: Optional[List[ScheduleRule]] = None, The schedule rules. + postprocs: Optional[List[Postproc"]] = None, + The postprocessors. + mutator_probs: Optional[Dict[Mutator, float]] + Mutators and their probability mass. task_name : Optional[str] = None The name of the tuning task. rand_state : int = -1 @@ -71,23 +78,31 @@ class TuneContext(Object): mod: Optional[IRModule] target: Optional[Target] - space_generator: "SpaceGenerator" - search_strategy: "SearchStrategy" - task_name: Optional[str] + space_generator: Optional["SpaceGenerator"] + search_strategy: Optional["SearchStrategy"] + sch_rules: List["ScheduleRule"] + postprocs: List["Postproc"] + mutator_probs: Optional[Dict["Mutator", float]] + task_name: str rand_state: int num_threads: int def __init__( self, mod: Optional[IRModule] = None, + *, target: Optional[Target] = None, space_generator: Optional["SpaceGenerator"] = None, search_strategy: Optional["SearchStrategy"] = None, sch_rules: Optional[List["ScheduleRule"]] = None, - task_name: Optional[str] = None, + postprocs: Optional[List["Postproc"]] = None, + mutator_probs: Optional[Dict["Mutator", float]] = None, + task_name: str = "main", rand_state: int = -1, num_threads: Optional[int] = None, ): + if isinstance(mod, PrimFunc): + mod = IRModule.from_expr(mod) if num_threads is None: num_threads = cpu_count() @@ -98,6 +113,8 @@ def __init__( space_generator, search_strategy, sch_rules, + postprocs, + mutator_probs, task_name, rand_state, num_threads, diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index ceb5f7210604..0ff08e4e1ac5 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -21,7 +21,7 @@ import shutil from typing import Any, Callable, List, Optional, Union -import psutil # type: ignore +import psutil import tvm from tvm._ffi import get_global_func, register_func from tvm.error import TVMError @@ -33,34 +33,48 @@ @register_func("meta_schedule.cpu_count") def _cpu_count_impl(logical: bool = True) -> int: - return psutil.cpu_count(logical=logical) or 1 - - -def cpu_count(logical: bool = True) -> int: """Return the number of logical or physical CPUs in the system - Parameters ---------- logical : bool = True If True, return the number of logical CPUs, otherwise return the number of physical CPUs - Returns ------- cpu_count : int The number of logical or physical CPUs in the system - Note ---- The meta schedule search infra intentionally does not adopt the following convention in TVM: - C++ API `tvm::runtime::threading::MaxConcurrency()` - Environment variable `TVM_NUM_THREADS` or - Environment variable `OMP_NUM_THREADS` - This is because these variables are dedicated to controlling the runtime behavior of generated kernels, instead of the host-side search. Setting these variables may interfere the host-side search with profiling of generated kernels when measuring locally. """ + return psutil.cpu_count(logical=logical) or 1 + + +@register_func("meta_schedule._process_error_message") +def _process_error_message(error_msg: str) -> str: + error_msg_lines = str(error_msg).splitlines() + if len(error_msg_lines) >= 50: + return "\n".join(error_msg_lines[:25] + ["..."] + error_msg_lines[-25:]) + return error_msg + + +def cpu_count(logical: bool = True) -> int: + """Return the number of logical or physical CPUs in the system + Parameters + ---------- + logical : bool = True + If True, return the number of logical CPUs, otherwise return the number of physical CPUs + Returns + ------- + cpu_count : int + The number of logical or physical CPUs in the system + """ return _cpu_count_impl(logical) @@ -69,17 +83,14 @@ def get_global_func_with_default_on_worker( default: Callable, ) -> Callable: """Get the registered global function on the worker process. - Parameters ---------- name : Union[None, str, Callable] If given a string, retrieve the function in TVM's global registry; If given a python function, return it as it is; Otherwise, return `default`. - default : Callable The function to be returned if `name` is None. - Returns ------- result : Callable @@ -107,7 +118,6 @@ def get_global_func_on_rpc_session( extra_error_msg: Optional[str] = None, ) -> PackedFunc: """Get a PackedFunc from the global registry from an RPCSession. - Parameters ---------- session : RPCSession @@ -116,7 +126,6 @@ def get_global_func_on_rpc_session( The name of the PackedFunc extra_error_msg : Optional[str] Extra information to provide in the error message - Returns ------- result : PackedFunc @@ -140,12 +149,10 @@ def remove_build_dir(artifact_path: str) -> None: def _json_de_tvm(obj: Any) -> Any: """Unpack a TVM nested container to a JSON object in python. - Parameters ---------- obj : Any The TVM nested container to be unpacked. - Returns ------- result : Any @@ -193,12 +200,10 @@ def batch_json_str2obj(json_strs: List[str]) -> List[Any]: def structural_hash(mod: IRModule) -> str: """Get the structural hash of a module. - Parameters ---------- mod : IRModule The module to be hashed. - Returns ------- result : str @@ -212,11 +217,24 @@ def structural_hash(mod: IRModule) -> str: return str(shash) +def _get_hex_address(handle: ctypes.c_void_p) -> str: + """Get the hexadecimal address of a handle. + Parameters + ---------- + handle : ctypes.c_void_p + The handle to be converted. + Returns + ------- + result : str + The hexadecimal address of the handle. + """ + return hex(ctypes.cast(handle, ctypes.c_void_p).value) + + def check_override( derived_class: Any, base_class: Any, required: bool = True, func_name: str = None ) -> Callable: """Check if the derived class has overridden the base class's method. - Parameters ---------- derived_class : Any @@ -228,7 +246,6 @@ def check_override( func_name : str Name of the method. Default value None, which would be set to substring of the given function, e.g. `f_generate`->`generate`. - Returns ------- func : Callable @@ -250,17 +267,3 @@ def inner(func: Callable): return func return inner - - -def _get_hex_address(handle: ctypes.c_void_p) -> str: - """Get the hexadecimal address of a handle. - Parameters - ---------- - handle : ctypes.c_void_p - The handle to be converted. - Returns - ------- - result : str - The hexadecimal address of the handle. - """ - return hex(ctypes.cast(handle, ctypes.c_void_p).value) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 09b847a3ba91..2deb4f25d92b 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -271,13 +271,17 @@ def _module_export(module, file_name): # fcompile, addons, kwargs? @register_func("tvm.relay.build") +def _build_module_no_factory_impl(mod, target, target_host, params, mod_name): + target, target_host = Target.check_and_update_host_consist(target, target_host) + return build(mod, target, params=params, mod_name=mod_name).module + + def _build_module_no_factory(mod, target=None, target_host=None, params=None, mod_name="default"): """A wrapper around build which discards the Python GraphFactoryRuntime. This wrapper is suitable to be used from other programming languages as the runtime::Module can be freely passed between language boundaries. """ - target, target_host = Target.check_and_update_host_consist(target, target_host) - return build(mod, target, params=params, mod_name=mod_name).module + return _build_module_no_factory_impl(mod, target, target_host, params, mod_name) def _reconstruct_from_deprecated_options(deprecated_params_target): diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index fc953771bf21..db3261e7a392 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -20,7 +20,7 @@ import synr import tvm.tir -from tvm.runtime import Object +from tvm.runtime import Object, String from tvm.ir import Span, Range from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind @@ -483,8 +483,14 @@ def create_loop_info( """ assert self.context and self.node, "call 'exit_scope' before 'enter_scope'" extent = end if begin == 0 else self.context.analyzer.simplify(end - begin) - self.annotations = annotations - self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, self.annotations)) + self.annotations: Mapping[str, Object] = {} + if annotations is not None: + self.annotations = { + key: String(val) if isinstance(val, str) else val + for key, val in annotations.items() + } + + self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, annotations)) @register diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 07ceb29ebf98..fa91fcb0200b 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -33,7 +33,7 @@ from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize -from .function import PrimFunc +from .function import PrimFunc, IndexMap, TensorIntrin from .op import call_packed, call_intrin, call_pure_extern, call_extern from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index ecbcd837cb72..b41eb97b5948 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -16,18 +16,19 @@ # under the License. """Function data types.""" -from typing import Mapping, Union +import inspect +from typing import Callable, List, Mapping, Union -import tvm._ffi -import tvm.runtime -from tvm.runtime import Object +from tvm._ffi import get_global_func, register_object from tvm.ir import BaseFunc -from .buffer import Buffer -from .expr import Var, PrimExpr +from tvm.runtime import Object, convert + from . import _ffi_api +from .buffer import Buffer +from .expr import PrimExpr, Var -@tvm._ffi.register_object("tir.PrimFunc") +@register_object("tir.PrimFunc") class PrimFunc(BaseFunc): """A function declaration expression. @@ -56,7 +57,7 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, spa param_list = [] buffer_map = {} if buffer_map is None else buffer_map for x in params: - x = tvm.runtime.convert(x) if not isinstance(x, Object) else x + x = convert(x) if not isinstance(x, Object) else x if isinstance(x, Buffer): var = Var(x.name, dtype="handle") param_list.append(var) @@ -67,7 +68,13 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, spa raise TypeError("params can only contain Var or Buffer") self.__init_handle_by_constructor__( - _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs, span # type: ignore + _ffi_api.PrimFunc, # type: ignore # pylint: disable=no-member + param_list, + body, + ret_type, + buffer_map, + attrs, + span, ) def with_body(self, new_body, span=None): @@ -141,7 +148,7 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None: func : PrimFunc The new function with parameter specialized """ - return _ffi_api.Specialize(self, param_map) # type: ignore + return _ffi_api.Specialize(self, param_map) # type: ignore # pylint: disable=no-member def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: """Print IRModule into TVMScript @@ -159,6 +166,95 @@ def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: script : str The TVM Script of the PrimFunc """ - return tvm._ffi.get_global_func("script.AsTVMScript")( - self, tir_prefix, show_meta - ) # type: ignore + return get_global_func("script.AsTVMScript")(self, tir_prefix, show_meta) # type: ignore + + +@register_object("tir.IndexMap") +class IndexMap(Object): + """A mapping from multi-dimensional indices to another set of multi-dimensional indices + + Parameters + ---------- + src_iters : list of Var + The source indices + tgt_iters : list of PrimExpr + The target indices + """ + + src_iters: List[Var] + """The source indices""" + + tgt_iters: List[PrimExpr] + """The target indices""" + + def __init__(self, src_iters: List[Var], tgt_iters: List[PrimExpr]): + self._init_handle_by_constructor( + _ffi_api.IndexMap, # type: ignore # pylint: disable=no-member + src_iters, + tgt_iters, + ) + + def apply(self, indices: List[PrimExpr]) -> List[PrimExpr]: + """Apply the index map to a set of indices + + Parameters + ---------- + indices : List[PriExpr] + The indices to be mapped + + Returns + ------- + result : List[PrimExpr] + The mapped indices + """ + return _ffi_api.IndexMapApply(self, indices) # type: ignore # pylint: disable=no-member + + @staticmethod + def from_func(func: Callable) -> "IndexMap": + """Create an index map from a function + + Parameters + ---------- + func : Callable + The function to map from source indices to target indices + """ + + def wrap(args: List[Var]) -> List[PrimExpr]: + result = func(*args) + if isinstance(result, tuple): + return list(result) + if not isinstance(result, list): + result = [result] + return result + + ndim = len(inspect.signature(func).parameters) + return _ffi_api.IndexMapFromFunc(ndim, wrap) # type: ignore # pylint: disable=no-member + + +@register_object("tir.TensorIntrin") +class TensorIntrin(Object): + """A function declaration expression. + + Parameters + ---------- + desc_func: PrimFunc + The function to describe the computation + + intrin_func: PrimFunc + The function for execution + """ + + def __init__(self, desc_func, intrin_func): + self.__init_handle_by_constructor__( + _ffi_api.TensorIntrin, desc_func, intrin_func # type: ignore # pylint: disable=no-member + ) + + @staticmethod + def register(name: str, desc_func: PrimFunc, intrin_func: PrimFunc): + return _ffi_api.TensorIntrinRegister( # pylint: disable=no-member + name, desc_func, intrin_func + ) + + @staticmethod + def get(name: str): + return _ffi_api.TensorIntrinGet(name) # pylint: disable=no-member diff --git a/python/tvm/tir/schedule/__init__.py b/python/tvm/tir/schedule/__init__.py index 5f0e169c43e3..66ac7b9d772b 100644 --- a/python/tvm/tir/schedule/__init__.py +++ b/python/tvm/tir/schedule/__init__.py @@ -22,3 +22,5 @@ from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError from .state import ScheduleDebugMask, ScheduleState from .trace import Trace + +from . import analysis diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py new file mode 100644 index 000000000000..7c0c77a372f3 --- /dev/null +++ b/python/tvm/tir/schedule/analysis.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Analysis used in TensorIR scheduling""" +from typing import List, Optional + +from ..buffer import Buffer +from ..stmt import For +from ..expr import PrimExpr +from ..function import IndexMap + +from . import _ffi_api + + +def suggest_index_map( + buffer: Buffer, + indices: List[PrimExpr], + loops: List[For], + predicate: PrimExpr, +) -> Optional[IndexMap]: + """Provided the access pattern to a buffer, suggest one of the possible layout + transformation to minimize the locality of the access pattern. + + Parameters + ---------- + buffer : Buffer + The buffer to be transformed. + indices : List[PrimExpr] + The access pattern to the buffer. + loops : List[For] + The loops above the buffer. + predicate : PrimExpr + The predicate of the access. + + Returns + ------- + index_map : Optional[IndexMap] + The suggested index map. None if no transformation is suggested. + """ + return _ffi_api.SuggestIndexMap( # type: ignore # pylint: disable=no-member + buffer, + indices, + loops, + predicate, + ) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 50905eed9169..2e70c3e22802 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -15,13 +15,14 @@ # specific language governing permissions and limitations # under the License. """The TensorIR schedule class""" -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error from tvm.ir import IRModule, PrimExpr from tvm.runtime import Object, String -from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc +from tvm.tir import Block, For, IntImm, IndexMap, PrimFunc, TensorIntrin +from tvm.tir.expr import FloatImm from . import _ffi_api from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod @@ -369,6 +370,31 @@ def sample_perfect_tile( ) ) + def sample_compute_location( + self, + block: BlockRV, + decision: Optional[int] = None, + ) -> LoopRV: + """Sample a compute-at location of the given block + + Parameters + ---------- + block : BlockRV + The block whose compute-at location is to be sampled + decision : Optional[int] + The sampling decision + + Returns + ------- + result : LoopRV + The sampled loop where the input block is to be computed at + """ + return _ffi_api.ScheduleSampleComputeLocation( # pylint: disable=no-member + self, + block, + decision, + ) + ########## Schedule: Get blocks & loops ########## @type_checked def get_block( @@ -1029,6 +1055,30 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: self, block, write_buffer_index, storage_scope ) + ########## Schedule: Data movement ########## + + def read_at( + self, + loop: LoopRV, + block: BlockRV, + read_buffer_index: int, + storage_scope: str, + ) -> BlockRV: + return _ffi_api.ScheduleReadAt( # type: ignore # pylint: disable=no-member + self, loop, block, read_buffer_index, storage_scope + ) + + def write_at( + self, + loop: LoopRV, + block: BlockRV, + write_buffer_index: int, + storage_scope: str, + ) -> BlockRV: + return _ffi_api.ScheduleWriteAt( # type: ignore # pylint: disable=no-member + self, loop, block, write_buffer_index, storage_scope + ) + ########## Schedule: Compute location ########## @type_checked @@ -1733,9 +1783,16 @@ def after_set_scope( ########## Schedule: Blockize & Tensorize ########## + def blockize(self, loop: LoopRV) -> BlockRV: + return _ffi_api.ScheduleBlockize(self, loop) # pylint: disable=no-member + + def tensorize(self, loop: LoopRV, intrin: Union[str, TensorIntrin]) -> None: + if isinstance(intrin, str): + intrin = String(intrin) + _ffi_api.ScheduleTensorize(self, loop, intrin) # pylint: disable=no-member + ########## Schedule: Annotation ########## - @type_checked def annotate( self, block_or_loop: Union[BlockRV, LoopRV], @@ -1752,45 +1809,6 @@ def annotate( The annotation key ann_val : Union[str, int, float, ExprRV] The annotation value - - Examples - -------- - - Before annotate, in TensorIR, the IR is: - - .. code-block:: python - - @T.prim_func - def before_annotate(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (128, 128)) - B = T.match_buffer(b, (128, 128)) - for i, j in T.grid(128, 128): - with T.block("B"): - vi, vj = T.axis.remap("SS", [i, j]) - B[vi, vj] = A[vi, vj] * 2.0 - - Create the schedule and do annotate: - - .. code-block:: python - - sch = tir.Schedule(before_annotate) - sch.annotate(sch.get_block("B"), "ann_key", "ann_value") - print(sch.mod["main"].script()) - - After applying annotate, the IR becomes: - - .. code-block:: python - - @T.prim_func - def after_annotate(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (128, 128)) - B = T.match_buffer(b, (128, 128)) - for i, j in T.grid(128, 128): - with T.block("B"): - vi, vj = T.axis.remap("SS", [i, j]) - T.block_attr({"ann_key", "ann_value"}) - B[vi, vj] = A[vi, vj] * 2.0 - """ if isinstance(ann_val, str): ann_val = String(ann_val) @@ -1798,11 +1816,10 @@ def after_annotate(a: T.handle, b: T.handle) -> None: ann_val = IntImm("int32", ann_val) elif isinstance(ann_val, float): ann_val = FloatImm("float32", ann_val) - _ffi_api.ScheduleAnnotate( # type: ignore # pylint: disable=no-member + _ffi_api.ScheduleAnnotate( # pylint: disable=no-member self, block_or_loop, ann_key, ann_val ) - @type_checked def unannotate(self, block_or_loop: Union[BlockRV, LoopRV], ann_key: str) -> None: """Unannotate a block/loop's annotation with key ann_key @@ -1812,48 +1829,83 @@ def unannotate(self, block_or_loop: Union[BlockRV, LoopRV], ann_key: str) -> Non The block/loop to be unannotated ann_key : str The annotation key + """ + _ffi_api.ScheduleUnannotate(self, block_or_loop, ann_key) # pylint: disable=no-member + + ########## Schedule: Layout transformation ########## + + def transform_layout( + self, + block: BlockRV, + buffer_index: int, + is_write_index: bool, + index_map: Union[IndexMap, Callable], + ) -> None: + """Apply a transformation represented by IndexMap to buffer + + Parameters + ---------- + block_rv : BlockRV + The block that accesses the target buffer + buffer_index: int + The index of the buffer in block's read or write region + is_write_index : bool + Whether the buffer_index is the index of the block's write region + index_map : Union[IndexMap, Callable] + The transformation to apply Examples -------- - Before unannotate, in TensorIR, the IR is: + Before transform_layout, in TensorIR, the IR is: .. code-block:: python @T.prim_func - def before_unannotate(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (128, 128)) - B = T.match_buffer(b, (128, 128)) + def before_transform_layout(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") for i, j in T.grid(128, 128): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) - T.block_attr({"ann_key", "ann_value"}) B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 - Create the schedule and do annotate: + Create the schedule and do transform_layout: .. code-block:: python - sch = tir.Schedule(before_unannotate) - sch.unannotate(sch.get_block("B"), "ann_key") + sch = tir.Schedule(before_storage_align) + sch.transform_layout(sch.get_block("B"), buffer_index=0, is_write_index=True, + index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16)) print(sch.mod["main"].script()) - After applying unannotate, the IR becomes: + After applying transform_layout, the IR becomes: .. code-block:: python @T.prim_func - def after_unannotate(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (128, 128)) - B = T.match_buffer(b, (128, 128)) + def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((8, 8, 16, 16), "float32") + C = T.match_buffer(c, (128, 128), "float32") for i, j in T.grid(128, 128): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) - B[vi, vj] = A[vi, vj] * 2.0 - + B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0 """ - _ffi_api.ScheduleUnannotate( # type: ignore # pylint: disable=no-member - self, block_or_loop, ann_key + if callable(index_map): + index_map = IndexMap.from_func(index_map) + _ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member + self, block, buffer_index, is_write_index, index_map ) ########## Schedule: Misc ########## diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 834335766551..513894cf8c04 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -624,6 +624,18 @@ def PlanAndUpdateBufferAllocationLocation(): return _ffi_api.PlanAndUpdateBufferAllocationLocation() # type: ignore +def ApplyBlockBoundPredicate(): + """Narrow the extents of some loops by checking whether some constraints in the block iter + bound predicates can be directly applied on the loops. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.ApplyBlockBoundPredicate() # type: ignore + + def ConvertBlocksToOpaque(): """Substitute all the block vars with the PrimExprs they are bound to, indicated by the corresponding iter_values in BlockRealize, and then convert the blocks into @@ -749,3 +761,36 @@ def ConvertForLoopsToSerial(): The result pass """ return _ffi_api.ConvertForLoopsToSerial() # type: ignore + + +def InjectSoftwarePipeline(): + """Transform annotated loops into pipelined one that parallelize producers and consumers + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectSoftwarePipeline() # type: ignore + + +def LowerAutoCopy(): + """Automatically do memory optimizations for auto copy blocks + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerAutoCopy() + + +def RenomalizeSplitPattern(): + """Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.RenormalizeSplitPattern() diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 55a1a5a1830e..f57ed2771654 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -574,6 +574,20 @@ bool IntSet::CanProveNonNegative() const { return false; } +bool IntSet::HasLowerBound() const { + if (const IntervalSetNode* s_int = (*this).as()) { + return s_int->HasLowerBound(); + } + return false; +} + +bool IntSet::HasUpperBound() const { + if (const IntervalSetNode* s_int = (*this).as()) { + return s_int->HasUpperBound(); + } + return false; +} + SignType IntSet::GetSignType() const { if (CanProvePositive()) { return kPositive; @@ -762,6 +776,17 @@ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom return EvalSet(e, ConvertDomMap(dom_map)); } +Array EvalSet(const Array& exprs, const Map& dom_map) { + Array result; + result.reserve(exprs.size()); + Analyzer ana; + IntervalSetEvaluator m(&ana, dom_map); + for (const PrimExpr& e : exprs) { + result.push_back(m.Eval(e)); + } + return result; +} + IntSet EvalSet(Range r, const Map& dom_map) { Analyzer ana; IntervalSetEvaluator m(&ana, dom_map); diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index c9d4b1edc3a0..08ad8a77019a 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -201,8 +201,9 @@ class IterMapRewriter : public ExprMutator { return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr))); } - IterSumExpr RewriteIterConstraint(const PrimExpr& expr, const PrimExpr& predicate_induced_min, - const PrimExpr& predicate_induced_max) { + IterSumExpr RewriteIterConstraint(const PrimExpr& expr, + const Optional& predicate_induced_min, + const Optional& predicate_induced_max) { return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_min, predicate_induced_max); } @@ -494,14 +495,16 @@ class IterMapRewriter : public ExprMutator { * \param predicate_induced_max Open upper bound from iter constraint, maybe undefined. * \return The Normalized expression. */ - IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, PrimExpr predicate_induced_min, - PrimExpr predicate_induced_max) { + IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, Optional predicate_induced_min, + Optional predicate_induced_max) { // normalize to zero base PrimExpr base = expr->base; if (!is_zero(base)) { expr.CopyOnWrite()->base = 0; - if (predicate_induced_min.defined()) predicate_induced_min = predicate_induced_min - base; - if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max - base; + if (predicate_induced_min.defined()) + predicate_induced_min = predicate_induced_min.value() - base; + if (predicate_induced_max.defined()) + predicate_induced_max = predicate_induced_max.value() - base; } Optional opt = TryFuseIters(expr); ICHECK(!opt.defined() || opt.value()->args.size() == 1); @@ -521,27 +524,28 @@ class IterMapRewriter : public ExprMutator { PrimExpr iter_min = mark_offset; PrimExpr iter_max = iter_min + mark->extent; if (predicate_induced_min.defined()) { - iter_min = max(predicate_induced_min, iter_min); + iter_min = max(predicate_induced_min.value(), iter_min); } if (predicate_induced_max.defined()) { - iter_max = min(predicate_induced_max, iter_max); + iter_max = min(predicate_induced_max.value(), iter_max); } - if (!is_zero(iter_min)) { - // structured form's offset should be updated - flattened_map_.erase(structured_form); - structured_form.CopyOnWrite()->base = -iter_min; - mark.CopyOnWrite()->source = structured_form; - flattened_map_[structured_form] = flattened_form; + if (analyzer_->CanProve(iter_min <= iter_max)) { + if (!is_zero(iter_min)) { + // structured form's offset should be updated + flattened_map_.erase(structured_form); + structured_form.CopyOnWrite()->base = -iter_min; + mark.CopyOnWrite()->source = structured_form; + flattened_map_[structured_form] = flattened_form; + } + mark.CopyOnWrite()->extent = iter_max - iter_min; + sum_fuse_map_[flattened_form] = {mark, iter_min}; + // we need to note down the flattened form of constrained iterators + // to check the validity of constraints, see also CheckConstraints() + constrained_iters_flattened_.push_back(flattened_form); + expr.CopyOnWrite()->args = Array({split}); + expr.CopyOnWrite()->base = base + iter_min; + return expr; } - mark.CopyOnWrite()->extent = iter_max - iter_min; - sum_fuse_map_[flattened_form] = {mark, iter_min}; - - // we need to note down the flattened form of constrained iterators - // to check the validity of constraints, see also CheckConstraints() - constrained_iters_flattened_.push_back(flattened_form); - expr.CopyOnWrite()->args = Array({split}); - expr.CopyOnWrite()->base = base + iter_min; - return expr; } Fail(Diagnostic::Error(expr->span) << "Fail to normalize " << expr << " with predicate bound [" << predicate_induced_min @@ -608,7 +612,7 @@ class IterMapRewriter : public ExprMutator { } } } - if (!base_scale) { + if (!base_scale || base_scale.value()->value < 0) { diag_ctx_.Emit(Diagnostic::Error(expr->span) << "Fuse iters failed, can not find a valid base scale"); return NullOpt; @@ -770,14 +774,15 @@ class IterMapRewriter : public ExprMutator { struct IterConstraint { // The expr of the iter PrimExpr iter; - // The expr of the lower_bound - PrimExpr lower_bound; - // The expr of the upper_bound - PrimExpr upper_bound; + // The expr of the lower_bound, maybe undefined + Optional lower_bound; + // The expr of the upper_bound, maybe undefined + Optional upper_bound; // The size of the iter, which is the number of nodes size_t expr_size = 0; - IterConstraint(PrimExpr iter, PrimExpr lower_bound, PrimExpr upper_bound, size_t size) + IterConstraint(PrimExpr iter, Optional lower_bound, Optional upper_bound, + size_t size) : iter(std::move(iter)), lower_bound(std::move(lower_bound)), upper_bound(std::move(upper_bound)), @@ -787,11 +792,11 @@ struct IterConstraint { /*! * \brief Split the predicate into `(a < b) && (c < d) && ...` * \param pred The predicate to be split. + * \param result The result of predicate split. * \return A list of IterConstraint, empty if the split failed. */ -std::vector MatchBoundConstraints(PrimExpr pred, - const Map& input_iters) { - std::vector result; +bool MatchBoundConstraints(PrimExpr pred, Map& input_iters, + std::vector& result) { arith::PVar lhs, rhs, rest; for (;;) { // try extract comparisions @@ -820,14 +825,14 @@ std::vector MatchBoundConstraints(PrimExpr pred, is_equal = true; is_finish = true; } else { - return std::vector(); + return false; } PrimExpr lhs_expr = lhs.Eval(); PrimExpr rhs_expr = rhs.Eval(); // we only accept predicate of integers if (!((lhs_expr->dtype.is_int() || lhs_expr->dtype.is_uint()) && (rhs_expr->dtype.is_int() || rhs_expr->dtype.is_uint()))) { - return std::vector(); + return false; } // determine iter and bound, if we can not distinguish them simply, // try divide (lhs - rhs) into itervar aware and itervar free parts @@ -863,35 +868,49 @@ std::vector MatchBoundConstraints(PrimExpr pred, lhs_expr = analyzer.Simplify(lhs_expr); rhs_expr = analyzer.Simplify(rhs_expr); } - PrimExpr lower_bound, upper_bound, iter; + Optional lower_bound = NullOpt, upper_bound = NullOpt; + PrimExpr iter; if (is_greater) { if (bound_at_left) { - // bound > iter + // bound > iter / bound >= iter upper_bound = is_equal ? lhs_expr + 1 : lhs_expr; iter = rhs_expr; } else { - // iter > bound + // iter > bound / iter >= bound lower_bound = is_equal ? rhs_expr : rhs_expr + 1; iter = lhs_expr; } } else { if (bound_at_left) { - // bound < iter + // bound < iter / bound <= iter lower_bound = is_equal ? lhs_expr : lhs_expr + 1; iter = rhs_expr; } else { - // iter < bound + // iter < bound / iter <= bound upper_bound = is_equal ? rhs_expr + 1 : rhs_expr; iter = lhs_expr; } } - result.emplace_back(iter, lower_bound, upper_bound, 0); + // If it is a predicate for input iters + if (const auto* var_ptr = iter.as()) { + auto it = input_iters.find(GetRef(var_ptr)); + if (it == input_iters.end()) { + return false; + } + PrimExpr iter_min = (*it).second->min; + PrimExpr iter_max = (*it).second->min + (*it).second->extent; + if (lower_bound.defined()) iter_min = max(iter_min, lower_bound.value()); + if (upper_bound.defined()) iter_max = min(iter_max, upper_bound.value()); + input_iters.Set(GetRef(var_ptr), Range(iter_min, iter_max)); + } else { + result.emplace_back(iter, lower_bound, upper_bound, 0); + } if (is_finish) { break; } pred = rest.Eval(); } - return result; + return true; } bool IterRangeSanityCheck(const Map& iter_ranges) { @@ -911,8 +930,10 @@ Array DetectIterMap(const Array& indices, const Map(); - std::vector constraints = MatchBoundConstraints(predicate, input_iters); - if (!is_one(predicate) && constraints.empty()) { + Map constrained_input_iters = input_iters; + std::vector constraints; + if (!is_one(predicate) && + !MatchBoundConstraints(predicate, constrained_input_iters, constraints)) { diag_ctx.Emit(Diagnostic::Error(predicate->span) << "Fail to collect constraints from iteration predicate: " << predicate); return Array(); @@ -929,10 +950,11 @@ Array DetectIterMap(const Array& indices, const Map(); } if (!rewriter.CheckConstraints()) { @@ -944,7 +966,10 @@ Array DetectIterMap(const Array& indices, const Map results; for (PrimExpr value : indices) { results.push_back(rewriter.Rewrite(value)); - if (rewriter.unresolved_count() != 0) return Array(); + if (rewriter.unresolved_count() != 0) { + diag_ctx.Emit(Diagnostic::Error(predicate->span) << "Affine mapping detection failed"); + return Array(); + } } // Step1: IterIndependenceChecker checks if the iterator are independent. if (!rewriter.CheckMapping(results, require_bijective)) { @@ -1305,7 +1330,8 @@ class IterMapToExprNormalizer : public ExprMutator { } else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent)) { return floordiv(source, expr->lower_factor) * expr->scale; } else { - return floormod(floordiv(source, expr->lower_factor), expr->extent) * expr->scale; + return floordiv(floormod(source, expr->lower_factor * expr->extent), expr->lower_factor) * + expr->scale; } } diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index ac176b2623a3..99f90b9be90e 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -196,6 +196,18 @@ class ModularSetAnalyzer::Impl : public ExprFunctorb); + if (b.is_const()) { + int64_t c2 = b.base; + ICHECK(c2 != 0) << "MathError: the divisor is 0"; + Entry a = VisitExpr(op->a); + int64_t coeff = ZeroAwareGCD(a.coeff, c2); + return Entry(coeff, a.base % c2); + } + return Everything(); + } + Entry VisitExpr_(const MinNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 4a99e10211b7..84473337a452 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -192,6 +192,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE(truncdiv(x, c1) * c1 + truncmod(x, c1), x); // floor div TVM_TRY_REWRITE(floordiv(x, c1) * c1 + floormod(x, c1), x); + TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2), + c2.Eval()->value > 0); // canonicalization rule // will try rewrite again after canonicalization. @@ -771,6 +773,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), x * floordiv(c1, c2) + floordiv(y, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), floordiv(x, floordiv(c2, c1)), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c2.Eval()->value % c1.Eval()->value == 0 && + CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); + TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); @@ -780,6 +787,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(y, c2) + x * floordiv(c1, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(x, floordiv(c2, c1)), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c2.Eval()->value % c1.Eval()->value == 0 && + CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); + TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), min(floordiv(y, c2), x * floordiv(c1, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7dc7b28b968b..79b51caf9090 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -36,6 +36,8 @@ #include #include +#include "../printer/text_printer.h" + namespace tvm { // Register build pipeline related options @@ -187,6 +189,14 @@ transform::Pass Filter(FCond fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); } +Pass Print() { + auto pass_func = [](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { + LOG(INFO) << tir::AsTVMScript(f); + return f; + }; + return tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.Print", {}); +} + Array CreatePassList(bool disable_loop_partition) { transform::PassContext pass_ctx = transform::PassContext::Current(); @@ -240,10 +250,14 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::LowerCrossThreadReduction()); pass_list.push_back(tir::transform::LowerInitBlock()); pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); + pass_list.push_back(tir::transform::ApplyBlockBoundPredicate()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); - pass_list.push_back(tir::transform::UnifyThreadBinding()); pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tir::transform::LowerAutoCopy()); + pass_list.push_back(tir::transform::UnifyThreadBinding()); pass_list.push_back(tir::transform::LowerMatchBuffer()); + pass_list.push_back(tir::transform::InjectSoftwarePipeline()); pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); @@ -261,12 +275,14 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::InjectVirtualThread()); pass_list.push_back(tir::transform::InjectDoubleBuffer()); pass_list.push_back(tir::transform::StorageRewrite()); + pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::UnrollLoop()); // Add user-defined phase-2 passes pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end()); // PHASE 3 + pass_list.push_back(tir::transform::RenormalizeSplitPattern()); pass_list.push_back(tir::transform::Simplify()); pass_list.push_back(tir::transform::RemoveNoOp()); pass_list.push_back(tir::transform::RewriteUnsafeSelect()); diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc new file mode 100644 index 000000000000..4faa2d3ba674 --- /dev/null +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -0,0 +1,1309 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include +#include +#include +#include +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { + +using support::NDIntSet; + +/*! \brief Type for multi-dimensional index */ +using MultiIndex = std::vector; +/*! \brief Vector of int64_t */ +using IntVec = std::vector; +/*! \brief Vector of for loops */ +using ForVec = std::vector; + +/*! + * \brief An unordered_map for (for, buffer) => V + * \tparam V The value type + */ +template +using ForBufferMap = std::unordered_map>; + +/*! \brief Given x, compute log2(|x| + 1) */ +inline double slog(double x) { return x >= 0 ? std::log2(x + 1) : std::log2(-x + 1); } + +namespace utils { + +/*! + * \brief Get the shape of the buffer + * \param buffer The buffer + * \param analyzer The analyzer + * \return The shape of the buffer + */ +std::vector GetBufferShape(const Buffer& buffer, arith::Analyzer* analyzer) { + int ndim = buffer->shape.size(); + std::vector result; + result.reserve(ndim); + for (const PrimExpr& i : buffer->shape) { + if (const IntImmNode* int_imm = i.as()) { + result.push_back(int_imm->value); + continue; + } + arith::ConstIntBound bound = analyzer->const_int_bound(i); + if (0 <= bound->max_value && bound->max_value < arith::ConstIntBound::kPosInf) { + result.push_back(bound->max_value); + } else { + result.push_back(1); + } + } + return result; +} + +/*! + * \brief Given a loop, return its `pragma_auto_unroll_max_step` annotation if it exists + * \param loop The loop to be checked + * \return The value of `pragma_auto_unroll_max_step` if it exists, or -1 if it does not exist + */ +int64_t GetPragmaAutoUnroll(const ForNode* loop) { + if (Optional auto_unroll = GetAnn(loop, tir::attr::pragma_auto_unroll_max_step)) { + return auto_unroll.value()->value; + } + return -1; +} + +/*! + * \brief Given a list of loops, return the extent of the first loop if the list is not empty, + * and the first loop has constant extent. Otherwise returns the default value given + * \param loops The list of loops to be checked + * \param default_value The default value to be returned if the list is empty or the first loop + * does not have constant extent + * \return The extent of the first loop if the list is not empty, or the first loop has constant + * extent. Otherwise returns the default value + */ +int64_t FirstLoopExtent(const ForVec& loops, int64_t default_value) { + if (!loops.empty()) { + if (const int64_t* extent = GetLoopIntExtent(loops[0])) { + return *extent; + } + } + return default_value; +} + +/*! + * \brief Relax each of the multi-indexing pattern according to the domains bound in the analyzer, + * and then union them into a single region + * \param multi_index_pattern A list of multi-index pattern to be relaxed + * \param numel The size of the single region after union + * \param analyzer The analyzer that contains the domain information + * \return The relaxed and unioned region + */ +IntVec RelaxAndUnion(const std::vector& multi_indices, int64_t* numel, + arith::Analyzer* analyzer) { + if (multi_indices.empty()) { + return {}; + } + int n_indices = multi_indices.size(); + int ndim = multi_indices[0].size(); + IntVec access_shape(ndim, 0); + for (int i = 0; i < ndim; ++i) { + int64_t minimum = arith::ConstIntBound::kPosInf; + int64_t maximum = arith::ConstIntBound::kNegInf; + for (int j = 0; j < n_indices; ++j) { + arith::ConstIntBound bound = analyzer->const_int_bound(multi_indices[j][i]); + minimum = std::min(minimum, bound->min_value); + maximum = std::max(maximum, bound->max_value); + } + *numel *= maximum - minimum + 1; + access_shape[i] = maximum - minimum + 1; + } + return access_shape; +} + +/*! + * \brief Given a list of multi-index pattern, return the minimal stride of a variable on it + * \param multi_indices The list of multi-index pattern + * \param buffer_stride The stride of the buffer + * \param var The variable to be checked + * \return The minimal stride of the variable on the multi-index pattern + */ +int64_t GetVarStride(const std::vector& multi_indices, const IntVec& buffer_stride, + const Var& var) { + class CoefficientExtractor : private ExprVisitor { + public: + static int64_t Extract(const PrimExpr& expr, const Var& var) { + CoefficientExtractor extractor(var); + extractor.VisitExpr(expr); + return (extractor.visited_var && !extractor.visited_mul && !extractor.visited_add) + ? 1 + : (extractor.visited_var ? extractor.stride : 0); + } + + private: + explicit CoefficientExtractor(const Var& var) + : var(var), stride(0), visited_var(false), visited_add(false), visited_mul(false) {} + + void VisitExpr_(const MulNode* node) override { + ExprVisitor::VisitExpr_(node); + if (visited_var && !visited_add) { + if (const auto* a = node->a.as()) { + visited_mul = true; + stride = a->value; + } else if (const auto* b = node->b.as()) { + visited_mul = true; + stride = b->value; + } + } + } + + void VisitExpr_(const AddNode* node) override { + ExprVisitor::VisitExpr_(node); + if (visited_var && !visited_mul) { + visited_add = true; + stride = 1; + } + } + + void VisitExpr_(const VarNode* node) override { + if (node == var.get()) { + visited_var = true; + stride = 2; + } + } + + const Var& var; + int64_t stride; + bool visited_var; + bool visited_add; + bool visited_mul; + }; + + constexpr int64_t kNotFound = std::numeric_limits::max(); + int ndim = buffer_stride.size(); + // Calculate the min stride possible + int64_t result = kNotFound; + for (const MultiIndex& multi_index : multi_indices) { + ICHECK_EQ(multi_index.size(), buffer_stride.size()); + // Find the rightest dimension that contains the given variable + for (int i = ndim - 1; i >= 0; --i) { + int64_t coef = CoefficientExtractor::Extract(multi_index[i], var); + if (coef != 0) { + result = std::min(result, std::abs(coef) * buffer_stride[i]); + break; + } + } + } + return (result == kNotFound) ? 0 : result; +} + +/*! + * \brief Converts a 2-dimensional STL vector to a TVM NDArray + * \param src The source 2-dimensional STL vector + * \return The converted TVM NDArray + */ +runtime::NDArray AsNDArray(const std::vector>& src) { + ICHECK(!src.empty()); + int n = src.size(); + int m = src[0].size(); + runtime::NDArray tgt = runtime::NDArray::Empty( + /*shape=*/{n, m}, + /*dtype=*/DLDataType{kDLFloat, 64, 1}, + /*ctx=*/DLDevice{kDLCPU, 0}); + double* data = static_cast(tgt->data); + for (const std::vector& row : src) { + for (double v : row) { + *data++ = v; + } + } + return tgt; +} + +} // namespace utils + +namespace transform { + +/*! + * \brief Create a pass that simplifies the IR for feature extraction + * \return The pass created + */ +Pass SimplifyForFeatureExtraction() { + class Simplifier : private StmtExprMutator { + public: + static Stmt Run(Stmt stmt) { return Simplifier()(std::move(stmt)); } + + private: + PrimExpr VisitExpr_(const SelectNode* node) final { return make_const(node->dtype, 1.0); } + + PrimExpr VisitExpr_(const VarNode* var) final { + if (unit_vars_.count(GetRef(var))) { + return make_const(var->dtype, 0.0); + } + return GetRef(var); + } + + Stmt VisitStmt_(const ForNode* loop) final { + if (is_zero(loop->min) && is_one(loop->extent) && loop->kind == ForKind::kSerial && + loop->annotations.empty()) { + unit_vars_.insert(loop->loop_var); + return VisitStmt(loop->body); + } else { + return StmtExprMutator::VisitStmt_(loop); + } + } + + std::unordered_set unit_vars_; + }; + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + PrimFuncNode* n = f.CopyOnWrite(); + n->body = Simplifier::Run(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.SimplifyConstMatrix", {}); +} + +/*! + * \brief Create a list of passes that preprocesses the IR for feature extraction + * \return The list of passes created + */ +Sequential PassListForPerStoreFeature() { + return Sequential({ + tir::transform::SimplifyForFeatureExtraction(), + tir::transform::LowerCrossThreadReduction(), + tir::transform::LowerInitBlock(), + tir::transform::PlanAndUpdateBufferAllocationLocation(), + tir::transform::ApplyBlockBoundPredicate(), + tir::transform::ConvertBlocksToOpaque(), + tir::transform::UnifyThreadBinding(), + tir::transform::CompactBufferAllocation(), + tir::transform::LowerMatchBuffer(), + tir::transform::Simplify(), + }); +} + +} // namespace transform + +/*! \brief A data structure managing loop nests */ +struct LoopNest { + int64_t prod = 1; // The product of the extents of all the loops + ForVec loops; // All the loops + IntVec auto_unroll; // The loops with auto unroll pragma + ForVec parallel; // The loops whose ForKind are kParallel + ForVec vectorize; // The loops whose ForKind are kVectorized + ForVec unroll; // The loops whose ForKind are kUnrolled + ForVec blockIdx_x; // The loops whose ForKind are kThreadBinding to blockIdx.x + ForVec blockIdx_y; // The loops whose ForKind are kThreadBinding to blockIdx.y + ForVec blockIdx_z; // The loops whose ForKind are kThreadBinding to blockIdx.z + ForVec threadIdx_x; // The loops whose ForKind are kThreadBinding to threadIdx.x + ForVec threadIdx_y; // The loops whose ForKind are kThreadBinding to threadIdx.y + ForVec threadIdx_z; // The loops whose ForKind are kThreadBinding to threadIdx.z + ForVec vthread; // The loops whose ForKind are kThreadBinding to vthread.* + + /*! + * \brief Push a new loop into the loop nest + * \param loop The loop to be pushed + * \param auto_unroll_attr The auto unroll attribute of the loop + * \return A list of for loops that the loop is bound to + */ + ForVec* Push(const ForNode* loop, int64_t* auto_unroll_attr) { + if (const int64_t* extent = GetLoopIntExtent(loop)) { + this->prod *= *extent; + } + this->loops.push_back(loop); + if ((*auto_unroll_attr = utils::GetPragmaAutoUnroll(loop)) > 0) { + this->auto_unroll.push_back(*auto_unroll_attr); + } + ForVec* ref_loops = nullptr; + if (loop->kind == ForKind::kParallel) { + ref_loops = ∥ + } else if (loop->kind == ForKind::kVectorized) { + ref_loops = &vectorize; + } else if (loop->kind == ForKind::kUnrolled) { + ref_loops = &unroll; + } else if (loop->kind == ForKind::kThreadBinding) { + std::string thread_tag = loop->thread_binding.value()->thread_tag; + if (thread_tag == "blockIdx.x") { + ref_loops = &blockIdx_x; + } else if (thread_tag == "blockIdx.y") { + ref_loops = &blockIdx_y; + } else if (thread_tag == "blockIdx.z") { + ref_loops = &blockIdx_z; + } else if (thread_tag == "threadIdx.x") { + ref_loops = &threadIdx_x; + } else if (thread_tag == "threadIdx.y") { + ref_loops = &threadIdx_y; + } else if (thread_tag == "threadIdx.z") { + ref_loops = &threadIdx_z; + } else if (support::StartsWith(thread_tag, "vthread")) { + ref_loops = &vthread; + } else { + LOG(FATAL) << "ValueError: Unable to recognize thread tag: " << thread_tag; + } + } + if (ref_loops != nullptr) { + ref_loops->push_back(loop); + } + return ref_loops; + } + + /*! + * \brief Pop the last loop from the loop nest + * \param loop The loop to be popped + * \param ref_loops The list of for loops that the loop is bound to + * \param auto_unroll_attr The auto unroll attribute of the loop + */ + void Pop(const ForNode* loop, ForVec* ref_loops, int auto_unroll_attr) { + if (ref_loops) { + ref_loops->pop_back(); + } + if (auto_unroll_attr > 0) { + this->auto_unroll.pop_back(); + } + if (const int64_t* extent = GetLoopIntExtent(loop)) { + this->prod /= *extent; + } + this->loops.pop_back(); + } +}; + +/****** Group 1: Computation related features ******/ + +namespace group1 { + +/*! \brief Group 1 features */ +struct Feature { + /*! \brief Arithmetic features */ + struct ArithOps { + // Float-point arithmetic features + int64_t float_mad = 0; // The number of float MAD (Multiply–add) ops + int64_t float_add_sub = 0; // The number of float add and sub ops + int64_t float_mul = 0; // The number of float multiply ops + int64_t float_div_mod = 0; // The number of float div and mod ops + int64_t float_cmp = 0; // The number of float comparison ops + int64_t float_math_func = 0; // The number of float math func calls + int64_t float_other_func = 0; // The number of other float func calls + // Integer arithmetic features + int64_t int_mad = 0; // The number of integer MAD (Multiply–add) ops + int64_t int_add_sub = 0; // The number of integer add and sub ops + int64_t int_mul = 0; // The number of integer multiply ops + int64_t int_div_mod = 0; // The number of integer div and mod ops + int64_t int_cmp = 0; // The number of integer comparison ops + int64_t int_math_func = 0; // The number of integer math func calls + int64_t int_other_func = 0; // The number of other integer func calls + // Other arithmetic features + int64_t bool_op = 0; // The number of bool ops + int64_t select_op = 0; // The number of select ops + + static constexpr int64_t kCount = 16; + + ArithOps() = default; + ArithOps(const BufferStoreNode* store, int64_t prod_loop_extent); + + void Export(std::vector* v) const { + double vs[] = { + slog(float_mad), slog(float_add_sub), slog(float_mul), slog(float_div_mod), + slog(float_cmp), slog(float_math_func), slog(float_other_func), // + slog(int_mad), slog(int_add_sub), slog(int_mul), slog(int_div_mod), + slog(int_cmp), slog(int_math_func), slog(int_other_func), // + slog(bool_op), slog(select_op), + }; + v->insert(v->end(), std::begin(vs), std::end(vs)); + } + }; + + /*! \brief Loop binding features */ + struct ForKindFeature { + enum class Pos : int { + kPosNone = 0, // Does not have this kind of annotation + kPosInnerSpatial = 1, // The annotated iterator is the innermost spatial iterator + kPosMiddleSpatial = 2, // The annotated iterator is a middle spatial iterator + kPosOuterSpatial = 3, // The annotated iterator is the outermost spatial iterator + kPosInnerReduce = 4, // The annotated iterator is the innermost reduce iterator + kPosMiddleReduce = 5, // The annotated iterator is a middle reduce iterator + kPosOuterReduce = 6, // The annotated iterator is the outermost reduce iterator + kPosMixed = 7, // The annotated iterator is a mixed space and reduce iterator + kEnd = 8, + }; + int64_t num = 0; // The number of iterators with the annotation + int64_t prod = 0; // The product of the lengths of iterators with the annotation + int64_t len = 0; // The length of the innermost iterator with the annotation + Pos pos = Pos::kPosMixed; // The position of the iterators with the annotation + + static constexpr int64_t kCount = 11; + + explicit ForKindFeature(const ForVec& loops); + + void Export(std::vector* v) const { + double vs[] = { + slog(num), + slog(prod), + slog(len), + static_cast(static_cast(pos) == 0), + static_cast(static_cast(pos) == 1), + static_cast(static_cast(pos) == 2), + static_cast(static_cast(pos) == 3), + static_cast(static_cast(pos) == 4), + static_cast(static_cast(pos) == 5), + static_cast(static_cast(pos) == 6), + static_cast(static_cast(pos) == 7), + }; + v->insert(v->end(), std::begin(vs), std::end(vs)); + } + }; + + ArithOps arith_ops; // Arithmetic features + ForKindFeature vectorize; // Loop binding features: kVectorize + ForKindFeature unroll; // Loop binding features: kUnroll + ForKindFeature parallel; // Loop binding features: kParallel + bool is_gpu = false; // If the program is running on GPU + int64_t blockIdx_x_len = 1; // The length of blockIdx.x + int64_t blockIdx_y_len = 1; // The length of blockIdx.y + int64_t blockIdx_z_len = 1; // The length of blockIdx.z + int64_t threadIdx_x_len = 1; // The length of threadIdx.x + int64_t threadIdx_y_len = 1; // The length of threadIdx.y + int64_t threadIdx_z_len = 1; // The length of threadIdx.z + int64_t vthread_len = 1; // The length of virtual thread + + static constexpr int64_t kCount = ArithOps::kCount + ForKindFeature::kCount * 3 + 8; + + explicit Feature(const BufferStoreNode* store, const LoopNest& loop_nest, bool is_gpu) + : arith_ops(store, loop_nest.prod), + vectorize(loop_nest.vectorize), + unroll(loop_nest.unroll), + parallel(loop_nest.parallel) { + if (is_gpu) { + this->is_gpu = true; + this->blockIdx_x_len = utils::FirstLoopExtent(loop_nest.blockIdx_x, 1); + this->blockIdx_y_len = utils::FirstLoopExtent(loop_nest.blockIdx_y, 1); + this->blockIdx_z_len = utils::FirstLoopExtent(loop_nest.blockIdx_z, 1); + this->threadIdx_x_len = utils::FirstLoopExtent(loop_nest.threadIdx_x, 1); + this->threadIdx_y_len = utils::FirstLoopExtent(loop_nest.threadIdx_y, 1); + this->threadIdx_z_len = utils::FirstLoopExtent(loop_nest.threadIdx_z, 1); + this->vthread_len = utils::FirstLoopExtent(loop_nest.vthread, 1); + } + } + + void Export(std::vector* v) const { + this->arith_ops.Export(v); + this->vectorize.Export(v); + this->unroll.Export(v); + this->parallel.Export(v); + double vs[] = { + static_cast(is_gpu), // + slog(blockIdx_x_len), slog(blockIdx_y_len), slog(blockIdx_z_len), + slog(threadIdx_x_len), slog(threadIdx_y_len), slog(threadIdx_z_len), + slog(vthread_len), + }; + v->insert(v->end(), std::begin(vs), std::end(vs)); + } +}; + +Feature::ArithOps::ArithOps(const BufferStoreNode* store, int64_t prod_loop_extent) { + class ArithOpCounter : public ExprVisitor { + public: +#define TVM_FEATURE_SIMPLE(Type, Counter) \ + void VisitExpr_(const Type* op) final { \ + result_.Counter += this->prod_loop_extent_; \ + ExprVisitor::VisitExpr_(op); \ + } +#define TVM_FEATURE_BINARY(Type, FloatCounter, IntCounter) \ + void VisitExpr_(const Type* op) final { \ + if (op->dtype.is_float()) { \ + result_.FloatCounter += this->prod_loop_extent_; \ + } else { \ + result_.IntCounter += this->prod_loop_extent_; \ + } \ + ExprVisitor::VisitExpr_(op); \ + } + TVM_FEATURE_SIMPLE(AndNode, bool_op); + TVM_FEATURE_SIMPLE(OrNode, bool_op); + TVM_FEATURE_SIMPLE(NotNode, bool_op); + TVM_FEATURE_SIMPLE(SelectNode, select_op); + TVM_FEATURE_BINARY(AddNode, float_add_sub, int_add_sub); + TVM_FEATURE_BINARY(SubNode, float_add_sub, int_add_sub); + TVM_FEATURE_BINARY(MulNode, float_mul, int_mul); + TVM_FEATURE_BINARY(DivNode, float_div_mod, int_div_mod); + TVM_FEATURE_BINARY(ModNode, float_div_mod, int_div_mod); + TVM_FEATURE_BINARY(FloorDivNode, float_div_mod, int_div_mod); + TVM_FEATURE_BINARY(FloorModNode, float_div_mod, int_div_mod); + TVM_FEATURE_BINARY(MaxNode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(MinNode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(EQNode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(NENode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(LTNode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(LENode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(GTNode, float_cmp, int_cmp); + TVM_FEATURE_BINARY(GENode, float_cmp, int_cmp); +#undef TVM_FEATURE_BINARY +#undef TVM_FEATURE_SIMPLE + + void VisitExpr_(const CallNode* op) final { + static auto op_call_effect_ = Op::GetAttrMap("TCallEffectKind"); + TCallEffectKind effect_kind = op_call_effect_[Downcast(op->op)]; + bool is_pure = + effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation; + if (is_pure) { + if (op->dtype.is_float()) { + result_.float_math_func += prod_loop_extent_; + } else { + result_.int_math_func += prod_loop_extent_; + } + } else { + if (op->dtype.is_float()) { + result_.float_other_func += prod_loop_extent_; + } else { + result_.int_other_func += prod_loop_extent_; + } + } + ExprVisitor::VisitExpr_(op); + } + + int64_t prod_loop_extent_; + ArithOps result_; + }; + ArithOpCounter counter; + counter.prod_loop_extent_ = prod_loop_extent; + counter(store->value); + *this = counter.result_; +} + +Feature::ForKindFeature::ForKindFeature(const ForVec& loops) { + if (loops.empty()) { + this->num = 0; + this->prod = 0; + this->len = 0; + this->pos = ForKindFeature::Pos::kPosNone; + } else { + const int64_t* last_loop_extent = GetLoopIntExtent(loops.back()); + this->num = loops.size(); + this->len = last_loop_extent ? *last_loop_extent : 1; + this->pos = ForKindFeature::Pos::kPosMixed; + int64_t& prod = this->prod = 1; + for (const ForNode* loop : loops) { + if (const int64_t* extent = GetLoopIntExtent(loop)) { + prod *= *extent; + } + } + } +} + +} // namespace group1 + +namespace group2 { + +/*! \brief Group 2 features */ +struct Feature { + enum class AccessType : int { + kRead = 0, // The buffer is read but not written + kWrite = 1, // The buffer is written but not read + kReadWrite = 2, // The buffer is both read and written + kUnknownRW = 3, // Unknown type + kEnd = 4, + }; + enum class ReuseType : int { + kLoopMultipleRead = 0, // Buffer reuse because accessed on each iteration of a loop + kSerialMultipleReadWrite = 1, // Buffer reuse because it is serially accessed + kNoReuse = 2, // No buffer reuse + kEnd = 3, + }; + + struct SubFeature { + // + const BufferNode* buffer = nullptr; + AccessType access_type = AccessType::kUnknownRW; + std::vector multi_indices = {}; + // + /*! \brief loop_accessed_numel[i][...] means the number of elements accessed by loops[i] */ + std::vector> loop_accessed_numel = {}; + IntVec access_shape; + int64_t num_continuous_bytes = 1; + // Stride information + int64_t min_stride = 0; + int64_t innermost_stride = 0; + int64_t prod_non_strided_loop_extent = 0; + // Reuse information + ReuseType reuse_type = ReuseType::kNoReuse; + double reuse_dis_iter = 0.0; + double reuse_dis_bytes = 0.0; + int64_t reuse_ct = 0; + // Features + double bytes; // The touched memory in bytes + double unique_bytes; // The touched unique memory in bytes + double lines; // The number of touched cache lines + double unique_lines; // The number touched unique cache lines + double bytes_d_reuse_ct; // bytes / reuse_ct + double unique_bytes_d_reuse_ct; // unique_bytes / reuse_ct + double lines_d_reuse_ct; // lines / reuse_ct + double unique_lines_d_reuse_ct; // unique_lines / reuse_ct + double stride; // The stride in access + + static constexpr int64_t kCount = 18; + + void Export(std::vector* v) const { + double vs[] = { + static_cast(static_cast(access_type) == 0), + static_cast(static_cast(access_type) == 1), + static_cast(static_cast(access_type) == 2), + // FeatureSet::BufferAccess::AccessType::kUnknownRW is ignored + slog(bytes), + slog(unique_bytes), + slog(lines), + slog(unique_lines), + static_cast(static_cast(reuse_type) == 0), + static_cast(static_cast(reuse_type) == 1), + static_cast(static_cast(reuse_type) == 2), + slog(reuse_dis_iter), + slog(reuse_dis_bytes), + slog(reuse_ct), + slog(bytes_d_reuse_ct), + slog(unique_bytes_d_reuse_ct), + slog(lines_d_reuse_ct), + slog(unique_lines_d_reuse_ct), + slog(stride), + }; + v->insert(v->end(), std::begin(vs), std::end(vs)); + } + + static void Pad(std::vector* v) { v->insert(v->end(), 18, 0.0); } + + void SetStride(const LoopNest& loop_nest, arith::Analyzer* analyzer); + + void SetReuse(const LoopNest& loop_nest, // + int64_t top_loop_touch_bytes, // + const ForBufferMap& buffer_touched_under_loop); + + void SetFeature(const LoopNest& loop_nest, int64_t cache_line_bytes); + + explicit SubFeature(const BufferNode* buffer, AccessType access_type, + std::vector multi_indices, int n_loops) + : buffer(buffer), + access_type(access_type), + multi_indices(multi_indices), + loop_accessed_numel(n_loops) {} + }; + + void Export(std::vector* v, int buffers_per_store) const { + int n = sub_features.size(); + for (int i = 0; i < buffers_per_store; ++i) { + if (i < n) { + sub_features[i].Export(v); + } else { + SubFeature::Pad(v); + } + } + } + + explicit Feature(const BufferStoreNode* store, const LoopNest& loop_nest, + int64_t cache_line_bytes, IntVec* for_touched_bytes, + ForBufferMap* buffer_touched_under_loop, arith::Analyzer* analyzer); + + void Init(const BufferStoreNode* store, int n_loops); + + void SetRegion(const LoopNest& loop_nest, // + IntVec* for_touched_bytes, // + ForBufferMap* buffer_touched_under_loop, // + arith::Analyzer* analyzer); + + std::vector sub_features; +}; + +void Feature::Init(const BufferStoreNode* store, int n_loops) { + struct Info { + AccessType access_type = AccessType::kUnknownRW; + std::vector multi_indices; + }; + std::unordered_map buffer_info; + { + Info& info = buffer_info[store->buffer.get()]; + info.access_type = AccessType::kWrite; + info.multi_indices.push_back({store->indices.begin(), store->indices.end()}); + } + PostOrderVisit(store->value, [&buffer_info](const ObjectRef& obj) -> void { + if (const BufferLoadNode* load = obj.as()) { + const BufferNode* buffer = load->buffer.get(); + Info& info = buffer_info[buffer]; + switch (info.access_type) { + case AccessType::kRead: + break; + case AccessType::kWrite: + info.access_type = AccessType::kReadWrite; + break; + case AccessType::kReadWrite: + break; + case AccessType::kUnknownRW: + default: + info.access_type = AccessType::kRead; + break; + } + if (info.access_type != AccessType::kReadWrite) { + info.multi_indices.push_back({load->indices.begin(), load->indices.end()}); + } + } + }); + this->sub_features.reserve(buffer_info.size()); + for (const auto& kv : buffer_info) { + this->sub_features.emplace_back(kv.first, kv.second.access_type, + std::move(kv.second.multi_indices), n_loops); + } +} + +void Feature::SetRegion(const LoopNest& loop_nest, IntVec* for_touched_bytes, + ForBufferMap* buffer_touched_under_loop, + arith::Analyzer* analyzer) { + int n_loops = loop_nest.loops.size(); + const std::vector& loops = loop_nest.loops; + // Step 1. Initialize and bind all the loop variables to a constant + *for_touched_bytes = IntVec(n_loops, 0); + for (int i = 0; i < n_loops; ++i) { + const ForNode* loop = loops[i]; + analyzer->Bind(loop->loop_var, loop->min, /*allow_override=*/true); + } + // Step 2. Corner case: no loops + if (n_loops == 0) { + // In this case, the `access_shape` is not calculated + for (SubFeature& feature : sub_features) { + feature.access_shape = IntVec(feature.buffer->shape.size(), 1); + } + return; + } + // Step 3. Gradually bind the loops from inner to outer, + // calculate the area the loops touch on each buffer + for (int i = n_loops - 1; i >= 0; --i) { + const ForNode* loop = loops[i]; + analyzer->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent), + /*allow_override=*/true); + int64_t& touched_bytes = (*for_touched_bytes)[i] = 0; + for (SubFeature& feature : sub_features) { + const BufferNode* buffer = feature.buffer; + // Note: `feature.access_shape` for `i == 0` is the only one preserved, + // while others are discarded + int64_t numel = 1; + feature.access_shape = utils::RelaxAndUnion(feature.multi_indices, &numel, analyzer); + feature.loop_accessed_numel[i][buffer] = numel; + touched_bytes += numel * buffer->dtype.bytes(); + (*buffer_touched_under_loop)[loop][buffer].push_back(numel); + } + } +} + +void Feature::SubFeature::SetStride(const LoopNest& loop_nest, arith::Analyzer* analyzer) { + int n_loops = loop_nest.loops.size(); + const std::vector& loops = loop_nest.loops; + // For each buffer, we find the loop stride on it + const BufferNode* buffer = this->buffer; + int ndim = this->buffer->shape.size(); + IntVec buffer_shape = utils::GetBufferShape(GetRef(buffer), analyzer); + // Calculate the buffer's stride from its shape + IntVec buffer_stride(ndim); + if (ndim >= 1) { + buffer_stride[ndim - 1] = 1; + for (int i = ndim - 2; i >= 0; --i) { + buffer_stride[i] = buffer_stride[i + 1] * buffer_shape[i + 1]; + } + } + // Calculate `num_continuous_bytes` + { + int64_t& num_continuous_bytes = this->num_continuous_bytes = 1; + const IntVec& access_shape = this->access_shape; + ICHECK_EQ(access_shape.size(), buffer_shape.size()); + for (int i = ndim - 1; i >= 0; --i) { + if (access_shape[i] == buffer_shape[i]) { + // TODO + num_continuous_bytes = buffer_shape[i] * buffer->dtype.bytes(); + break; + } + } + } + // Enumerate loops from inner to outer + int i = 0; + // Calculate this->min_stride + int64_t& stride = this->min_stride = 0; + for (i = n_loops - 1; i >= 0; --i) { + stride = utils::GetVarStride(this->multi_indices, buffer_stride, loops[i]->loop_var); + if (stride != 0) { + break; + } + } + // Calculate this->innermost_stride + this->innermost_stride = (i == n_loops - 1) ? stride : 0; + // Calculate this->prod + int64_t& prod = this->prod_non_strided_loop_extent = 1; + for (int j = n_loops - 1; j > i; --j) { + if (const int64_t* extent = GetLoopIntExtent(loops[n_loops - 1])) { // TODO + prod *= *extent; + } + } +} + +void Feature::SubFeature::SetReuse(const LoopNest& loop_nest, int64_t top_loop_touch_bytes, + const ForBufferMap& buffer_touched_under_loop) { + const BufferNode* buffer = this->buffer; + // Step 0. Collect all `Var`s that appears in the buffer region + std::unordered_set region_vars; + for (const MultiIndex& multi_index : this->multi_indices) { + for (const PrimExpr& index : multi_index) { + PostOrderVisit(index, [®ion_vars](const ObjectRef& obj) -> void { + if (const auto* var = obj.as()) { + region_vars.insert(var); + } + }); + } + } + // Default case: no reuse + ReuseType& reuse_type = this->reuse_type = ReuseType::kNoReuse; + double& reuse_dis_iter = this->reuse_dis_iter = 0; + double& reuse_dis_bytes = this->reuse_dis_bytes = 0; + int64_t& reuse_ct = this->reuse_ct = 0; + + // Step 3.2. Enumerate loops from inner to outer, find the first loop with reuse + int n_loops = loop_nest.loops.size(); + const std::vector& loops = loop_nest.loops; + for (int i = n_loops - 1; i >= 0; --i) { + const ForNode* loop = loops[i]; + // Case 1. Find an invariant loop, i.e. reuse with kLoopMultipleRead + if (!region_vars.count(loop->loop_var.get())) { + reuse_type = ReuseType::kLoopMultipleRead; + if (const int64_t* extent = GetLoopIntExtent(loop)) { + reuse_ct = *extent; + } else { + reuse_ct = 1; + } + reuse_dis_iter = 1; + for (int j = n_loops - 1; j > i; --j) { + if (const int64_t* extent = GetLoopIntExtent(loops[j])) { + reuse_dis_iter *= *extent; + } + } + reuse_dis_bytes = 0.0; + if (i == n_loops - 1) { + reuse_dis_bytes = top_loop_touch_bytes; + } else { + for (const auto& iter : buffer_touched_under_loop.at(loops[i + 1])) { + const BufferNode* buffer = iter.first; + const IntVec& numels = iter.second; + int64_t numel = std::accumulate(numels.begin(), numels.end(), int64_t(0)); + reuse_dis_bytes += numel * buffer->dtype.bytes(); + } + } + break; + } + // Case 2. Find serial reuse, i.e. reuse with kSerialMultipleReadWrite + const IntVec& touched = buffer_touched_under_loop.at(loop).at(buffer); + if (touched.size() >= 2) { + int64_t extent = 1; + if (const int64_t* ext = GetLoopIntExtent(loop)) { + extent = *ext; + } + reuse_type = ReuseType::kSerialMultipleReadWrite; + reuse_ct = touched.size() - 1; + reuse_dis_iter = *std::min_element(touched.begin(), touched.end()); + reuse_dis_bytes = 0.0; + for (const auto& iter : buffer_touched_under_loop.at(loop)) { + const BufferNode* buffer = iter.first; + const IntVec& numels = iter.second; + int64_t numel = std::accumulate(numels.begin(), numels.end(), int64_t(0)); + reuse_dis_bytes += numel * buffer->dtype.bytes(); + } + reuse_dis_iter /= extent; + reuse_dis_bytes /= extent; + break; + } + } +} + +void Feature::SubFeature::SetFeature(const LoopNest& loop_nest, int64_t cache_line_bytes) { + int64_t dtype_bytes = this->buffer->dtype.bytes(); + this->stride = this->innermost_stride; + this->bytes = dtype_bytes * loop_nest.prod; + if (loop_nest.loops.empty()) { + this->unique_bytes = 1; + this->lines = 1; + this->unique_lines = 1; + } else { + this->unique_bytes = this->loop_accessed_numel.front().at(buffer) * dtype_bytes; + this->lines = static_cast(loop_nest.prod) / this->prod_non_strided_loop_extent * + std::min(1.0, 1.0 * this->min_stride * dtype_bytes / cache_line_bytes); + this->lines = std::max(1.0, this->lines); + this->unique_lines = static_cast(this->unique_bytes) / + std::min(cache_line_bytes, this->num_continuous_bytes); + this->unique_lines = std::max(1.0, this->unique_lines); + } + double proxy_reuse_ct = this->reuse_ct > 0 ? this->reuse_ct : 0.5; + this->bytes_d_reuse_ct = this->bytes / proxy_reuse_ct; + this->unique_bytes_d_reuse_ct = this->unique_bytes / proxy_reuse_ct; + this->lines_d_reuse_ct = this->lines / proxy_reuse_ct; + this->unique_lines_d_reuse_ct = this->unique_lines / proxy_reuse_ct; +} + +Feature::Feature(const BufferStoreNode* store, const LoopNest& loop_nest, int64_t cache_line_bytes, + IntVec* for_touched_bytes, ForBufferMap* buffer_touched_under_loop, + arith::Analyzer* analyzer) { + int n_loops = loop_nest.loops.size(); + // Step 0. Initialize data structures + this->Init(store, n_loops); + // Step 1. Calculate region-related feature + this->SetRegion(loop_nest, for_touched_bytes, buffer_touched_under_loop, analyzer); + // Step 2. Calculate stride-related feature + for (auto& feature : sub_features) { + feature.SetStride(loop_nest, analyzer); + } + // Step 3. Calculate reuse-related feature + int64_t top_loop_touch_bytes = 0.0; + if (n_loops > 0) { + for (const SubFeature& feature : sub_features) { + int64_t bytes = feature.buffer->dtype.bytes(); + int64_t n_buffer = feature.loop_accessed_numel[0].size(); + top_loop_touch_bytes += bytes * n_buffer; + } + } + for (auto& feature : sub_features) { + feature.SetReuse(loop_nest, top_loop_touch_bytes, *buffer_touched_under_loop); + } + // Step 4. Calculate rest of the features + for (auto& feature : sub_features) { + feature.SetFeature(loop_nest, cache_line_bytes); + } + // Step 5. Sort the features + std::sort(sub_features.begin(), sub_features.end(), [](const SubFeature& a, const SubFeature& b) { + if (a.lines != b.lines) { + return a.lines > b.lines; + } + if (a.bytes != b.bytes) { + return a.bytes > b.bytes; + } + return a.buffer->name < b.buffer->name; + }); +} + +} // namespace group2 + +namespace group3 { + +/*! \brief Group 3 feature */ +struct Feature { + std::vector arith_intensity_curve; + + void Export(std::vector* v) const { + v->insert(v->end(), arith_intensity_curve.begin(), arith_intensity_curve.end()); + } + + explicit Feature(int n_samples, const LoopNest& loop_nest, const IntVec& for_touched_bytes, + const group1::Feature::ArithOps& arith_ops) + : arith_intensity_curve(n_samples, 0.0) { + const std::vector& loops = loop_nest.loops; + ICHECK_EQ(loops.size(), for_touched_bytes.size()); + int n_loops = loops.size(); + // Calculate `memory_bytes` + std::vector memory_bytes; + memory_bytes.resize(n_loops); + for (int i = 0; i < n_loops; ++i) { + memory_bytes[n_loops - 1 - i] = std::log2(for_touched_bytes[i]); + } + // Calculate `compute_ops` and `cur_compute_ops` + std::vector compute_ops; + double total_compute_ops = arith_ops.float_mad + arith_ops.float_add_sub + arith_ops.float_mul + + arith_ops.float_div_mod + arith_ops.float_cmp + + arith_ops.float_math_func + arith_ops.float_other_func; + total_compute_ops /= loop_nest.prod; + for (int i = n_loops - 1; i >= 0; --i) { + if (const int64_t* extent = GetLoopIntExtent(loops[i])) { + total_compute_ops *= *extent; + } + compute_ops.push_back(std::log2(total_compute_ops)); + } + // Fill the feature set + if (total_compute_ops <= 0 || compute_ops.empty()) { + for (int i = 0; i < n_samples; ++i) { + arith_intensity_curve[i] = 0.0; + } + return; + } + total_compute_ops = compute_ops.back(); // i.e. total_compute_ops = log2(total_compute_ops) + int p = 0; + for (int i = 0; i < n_samples; ++i) { + double& result = arith_intensity_curve[i]; + double cur_compute_ops = static_cast(i + 1) / n_samples * total_compute_ops; + // Find the first `p` that `compute[p] >= total * (i + 1) / N` + for (; p < n_loops; ++p) { + if (compute_ops[p] >= cur_compute_ops - 1e-4) { + break; + } + } + CHECK_LT(p, n_loops); + if (p == 0) { + result = compute_ops[p] / memory_bytes[p]; + } else { + double base = compute_ops[p - 1] / memory_bytes[p - 1]; + double slope = + (compute_ops[p] / memory_bytes[p] - compute_ops[p - 1] / memory_bytes[p - 1]) / + (compute_ops[p] - compute_ops[p - 1]); + result = base + slope * (cur_compute_ops - compute_ops[p - 1]); + } + } + } +}; + +} // namespace group3 + +namespace group4 { + +/*! \brief Group 4 feature */ +struct Feature { + int64_t alloc_size = 0; // The size of allocated buffer in bytes + int64_t alloc_prod = 0; // alloc_outer_prod * alloc_inner_prod + int64_t alloc_outer_prod = 1; // The product of lengths of loops outside the scope of the alloc + + static constexpr int64_t kCount = 4; + + void Export(std::vector* v, int64_t outer_prod) const { + double vs[] = { + slog(alloc_size), + slog(alloc_prod), + slog(alloc_outer_prod), + slog(static_cast(outer_prod) / alloc_outer_prod), + }; + v->insert(v->end(), std::begin(vs), std::end(vs)); + } + + Feature() = default; + + explicit Feature(const LoopNest& loop_nest, const Buffer& buffer, arith::Analyzer* analyzer) { + std::vector shape = utils::GetBufferShape(buffer, analyzer); + int64_t numel = 1; + for (int64_t x : shape) { + numel *= x; + } + alloc_size = numel * buffer->dtype.bytes(); + alloc_prod = numel * loop_nest.prod; + alloc_outer_prod = loop_nest.prod; + } +}; + +} // namespace group4 + +namespace group5 { + +/*! \brief Group 5 feature */ +struct Feature { + int64_t outer_prod; // The product of lengths of outer loops + int num_loops; // The number of outer loops + int auto_unroll_max_step; // The value of pragma "auto_unroll_max_step" + + static constexpr int64_t kCount = 3; + + void Export(std::vector* v) const { + double vs[] = { + slog(outer_prod), + slog(num_loops), + slog(auto_unroll_max_step), + }; + v->insert(v->end(), std::begin(vs), std::end(vs)); + } + + explicit Feature(const LoopNest& loop_nest) { + this->outer_prod = loop_nest.prod; + this->num_loops = loop_nest.loops.size(); + this->auto_unroll_max_step = loop_nest.auto_unroll.empty() ? 0 : loop_nest.auto_unroll.back(); + } +}; + +} // namespace group5 + +/*! \brief The feature extracted */ +struct Feature { + const BufferNode* buffer = nullptr; + int buffer_order = -1; + std::unique_ptr group1 = nullptr; + std::unique_ptr group2 = nullptr; + std::unique_ptr group3 = nullptr; + std::unique_ptr group4 = nullptr; + std::unique_ptr group5 = nullptr; + + bool operator<(const Feature& other) const { return buffer_order < other.buffer_order; } +}; + +/*! \brief The main feature extractor */ +class PerStoreFeatureCollector : private StmtVisitor { + public: + static std::vector Collect(bool is_gpu, int64_t cache_line_bytes, + int64_t arith_intensity_curve_num_samples, + const IRModule& mod) { + PerStoreFeatureCollector collector(is_gpu, cache_line_bytes, arith_intensity_curve_num_samples); + for (const auto& kv : mod->functions) { + if (const PrimFuncNode* func = kv.second.as()) { + collector(func->body); + for (const auto& it : func->buffer_map) { + collector.HandleBufferAlloc(it.second); + } + } + } + std::vector result; + result.reserve(collector.buffer_features_.size()); + for (auto& it : collector.buffer_features_) { + Feature& feature = it.second; + if (feature.buffer != nullptr) { + ICHECK(feature.group1); + ICHECK(feature.group2); + ICHECK(feature.group3); + ICHECK(feature.group5); + if (feature.group4 == nullptr) { + feature.group4 = std::make_unique(); + } + result.push_back(std::move(feature)); + } + } + std::sort(result.begin(), result.end()); + return result; + } + + private: + void VisitStmt_(const ForNode* loop) final { + int64_t auto_unroll; + ForVec* for_vec = loop_nest_.Push(loop, &auto_unroll); + StmtVisitor::VisitStmt_(loop); + loop_nest_.Pop(loop, for_vec, auto_unroll); + } + + void VisitStmt_(const BufferStoreNode* store) final { + if (store->value->IsInstance() || store->value->IsInstance()) { + return; + } + const BufferNode* buffer = store->buffer.get(); + Feature& feature = buffer_features_[buffer]; + if (feature.buffer == nullptr) { + feature.buffer = buffer; + feature.buffer_order = buffer_features_.size(); + } + feature.group1 = std::make_unique(store, loop_nest_, is_gpu_); + feature.group2 = + std::make_unique(store, loop_nest_, cache_line_bytes_, &for_touched_bytes_, + &buffer_touched_under_loop_, &analyzer_); + feature.group3 = + std::make_unique(arith_intensity_curve_num_samples_, loop_nest_, + for_touched_bytes_, feature.group1->arith_ops); + feature.group5 = std::make_unique(loop_nest_); + } + + void VisitStmt_(const BlockNode* block) final { + StmtVisitor::VisitStmt_(block); + for (const Buffer& buffer : block->alloc_buffers) { + HandleBufferAlloc(buffer); + } + } + + void HandleBufferAlloc(const Buffer& buffer) { + Feature& feature = buffer_features_[buffer.get()]; + feature.group4 = std::make_unique(loop_nest_, buffer, &analyzer_); + } + + explicit PerStoreFeatureCollector(bool is_gpu, int64_t cache_line_bytes, + int64_t arith_intensity_curve_num_samples) + : is_gpu_(is_gpu), + cache_line_bytes_(cache_line_bytes), + arith_intensity_curve_num_samples_(arith_intensity_curve_num_samples) {} + + bool is_gpu_; + int64_t cache_line_bytes_; + int64_t arith_intensity_curve_num_samples_; + arith::Analyzer analyzer_; + LoopNest loop_nest_ = {}; + IntVec for_touched_bytes_ = {}; + ForBufferMap buffer_touched_under_loop_ = {}; + std::unordered_map buffer_features_ = {}; +}; + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +class PerStoreFeatureNode : public FeatureExtractorNode { + public: + int buffers_per_store; + int arith_intensity_curve_num_samples; + int cache_line_bytes; + int feature_vector_length; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("buffers_per_store", &buffers_per_store); + v->Visit("arith_intensity_curve_num_samples", &arith_intensity_curve_num_samples); + v->Visit("cache_line_bytes", &cache_line_bytes); + v->Visit("feature_vector_length", &feature_vector_length); + } + + void ExtractSingle(IRModule mod, bool is_gpu, std::vector>* results) { + static transform::Sequential passes = tir::transform::PassListForPerStoreFeature(); + mod = passes(std::move(mod)); + std::vector features = tir::PerStoreFeatureCollector::Collect( + is_gpu, this->cache_line_bytes, this->arith_intensity_curve_num_samples, mod); + int n_features = features.size(); + results->resize(n_features); + for (int i = 0; i < n_features; ++i) { + const tir::Feature& feature = features[i]; + std::vector& result = (*results)[i]; + result.reserve(feature_vector_length); + feature.group1->Export(&result); + feature.group2->Export(&result, this->buffers_per_store); + feature.group3->Export(&result); + feature.group4->Export(&result, feature.group5->outer_prod); + feature.group5->Export(&result); + ICHECK_EQ(static_cast(result.size()), feature_vector_length); + } + } + + Array ExtractFrom(const TuneContext& tune_context, + const Array& candidates) { + bool is_gpu = tune_context->target.value()->kind->name == "cuda"; + std::vector results; + results.resize(candidates.size()); + auto f = [this, is_gpu, &candidates, &results](int, int task_id) -> void { + const auto& candidate = candidates[task_id]; + std::vector> features; + ExtractSingle(DeepCopyIRModule(candidate->sch->mod()), is_gpu, &features); + results[task_id] = tir::utils::AsNDArray(features); + }; + support::parallel_for_dynamic(0, candidates.size(), tune_context->num_threads, f); + return results; + } + + static constexpr const char* _type_key = "meta_schedule.PerStoreFeature"; + TVM_DECLARE_FINAL_OBJECT_INFO(PerStoreFeatureNode, FeatureExtractorNode); +}; + +FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store, + int arith_intensity_curve_num_samples, + int cache_line_bytes) { + ObjectPtr n = make_object(); + n->buffers_per_store = buffers_per_store; + n->arith_intensity_curve_num_samples = arith_intensity_curve_num_samples; + n->cache_line_bytes = cache_line_bytes; + n->feature_vector_length = tir::group1::Feature::kCount + // + tir::group2::Feature::SubFeature::kCount * buffers_per_store + // + arith_intensity_curve_num_samples + // + tir::group4::Feature::kCount + // + tir::group5::Feature::kCount; + return FeatureExtractor(n); +} + +TVM_REGISTER_NODE_TYPE(PerStoreFeatureNode); +TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPerStoreFeature") + .set_body_typed(FeatureExtractor::PerStoreFeature); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index cf4262814947..f62f5a91d394 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -112,7 +112,22 @@ ApplyHistoryBest::ApplyHistoryBest(Database database) { Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, Optional> dispatched) { - throw; + ICHECK(dispatched.defined()); + ICHECK_EQ(dispatched.value().size(), 1); + IRModule prim_mod = dispatched.value()[0]; + ICHECK(HasOnlyOneFunction(prim_mod)) << prim_mod; + ICHECK(HasOnlyOneFunction(mod)) << mod; + const auto* parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod"); + prim_mod = (*parse_mod_func)(prim_mod); + if (database->HasWorkload(prim_mod)) { + Array records = database->GetTopK(database->CommitWorkload(prim_mod), 1); + // todo(@zxybazh): check if records always exists when the database has the workload + if (records.size() == 1) { + LOG(INFO) << "Applied history best for " << task_name << "!"; + return records[0]->workload->mod; + } + } + return NullOpt; } /**************** FFI ****************/ @@ -146,6 +161,10 @@ TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQuery") TVM_REGISTER_GLOBAL("meta_schedule.TaskExtraction").set_body_typed([]() -> TaskExtraction { return TaskExtraction(); }); +TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest") + .set_body_typed([](Database database) -> ApplyHistoryBest { + return ApplyHistoryBest(database); + }); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/measure_callback/update_cost_model.cc b/src/meta_schedule/measure_callback/update_cost_model.cc index 58c86abadfe9..00f6f94eb7d3 100644 --- a/src/meta_schedule/measure_callback/update_cost_model.cc +++ b/src/meta_schedule/measure_callback/update_cost_model.cc @@ -33,7 +33,20 @@ class UpdateCostModelNode : public MeasureCallbackNode { ICHECK(task->measure_candidates.defined()) // << "Task's measure candidates must be present!"; CostModel cost_model = task_scheduler->cost_model.value(); - cost_model->Update(task, task->measure_candidates.value(), runner_results); + ICHECK_EQ(measure_candidates.size(), builder_results.size()); + ICHECK_EQ(runner_results.size(), builder_results.size()); + int n = builder_results.size(); + Array pruned_candidate; + Array pruned_runner_result; + pruned_candidate.reserve(n); + pruned_runner_result.reserve(n); + for (int i = 0; i < n; i++) { + if (!builder_results[i]->error_msg.defined()) { + pruned_candidate.push_back(measure_candidates[i]); + pruned_runner_result.push_back(runner_results[i]); + } + } + cost_model->Update(task, pruned_candidate, pruned_runner_result); } static constexpr const char* _type_key = "meta_schedule.UpdateCostModel"; diff --git a/src/meta_schedule/mutator/mutate_compute_location.cc b/src/meta_schedule/mutator/mutate_compute_location.cc new file mode 100644 index 000000000000..9c495e1c50cd --- /dev/null +++ b/src/meta_schedule/mutator/mutate_compute_location.cc @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +using tir::Instruction; +using tir::InstructionKind; +using tir::Trace; + +/*! \brief A mutator that mutates the compute-at location decision of SampleComputeLocation */ +class MutateComputeLocationNode : public MutatorNode { + public: + /*! \brief JSON representation of the workload */ + std::string json_mod_; + + void VisitAttrs(tvm::AttrVisitor* v) {} + static constexpr const char* _type_key = "meta_schedule.MutateComputeLocation"; + TVM_DECLARE_FINAL_OBJECT_INFO(MutateComputeLocationNode, MutatorNode); + + public: + struct Candidate { + /*! \brief The SampleComputeLocation instruction */ + Instruction inst; + /*! \brief The candidate compute-at locations */ + std::vector locs; + + explicit Candidate(Instruction inst, std::vector locs) + : inst(std::move(inst)), locs(std::move(locs)) {} + }; + + std::vector FindCandidates(const Trace& trace, TRandState* rand_state); + + // Inherit from `MutatorNode` + void InitializeWithTuneContext(const TuneContext& context) final { + this->json_mod_ = SaveJSON(context->mod.value()); + } + // Inherit from `MutatorNode` + Optional Apply(const Trace& trace, TRandState* rand_state) final; +}; + +/*! + * \brief Find all appearances of instruction `SampleComputeLocation` whose decision can be mutated + * to at lease one other value + * \param trace The trace from which to find the instructions + * \return All the candidate instructions together with the candidate compute-at locations + */ +std::vector MutateComputeLocationNode::FindCandidates( + const Trace& trace, TRandState* rand_state) { + tir::Schedule sch = tir::Schedule::Traced( // + /*mod=*/Downcast(LoadJSON(this->json_mod_)), // + /*rand_state=*/ForkSeed(rand_state), // + /*debug_mode=*/0, // + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); + static InstructionKind inst_sample_compute_location = + InstructionKind::Get("SampleComputeLocation"); + std::vector candidates; + + auto f_provide_decision = [&](const tir::Instruction& inst, // + const Array& inputs, // + const Array& attrs, // + const ObjectRef& decision) -> ObjectRef { + if (inst->kind.same_as(inst_sample_compute_location)) { + // Step 1. Extract the instruction input and the old decision. + ICHECK_EQ(inputs.size(), 1); + tir::StmtSRef block_sref = sch->GetSRef(Downcast(inputs[0])); + int old_decision = Downcast(decision)->value; + // Step 2. Collect all the compute-at locations. + Array location_srefs; + std::vector location_indices; + std::tie(location_srefs, location_indices) = CollectComputeLocation(sch->state(), block_sref); + // Step 3. Remove the old decision. + auto it = std::find(location_indices.begin(), location_indices.end(), old_decision); + if (it != location_indices.end()) { + location_srefs.erase(location_srefs.begin() + (it - location_indices.begin())); + location_indices.erase(it); + } + ICHECK_EQ(location_srefs.size(), location_indices.size()); + // Step 4. Add a new candidate if there are at least one remaining compute-at position. + if (!location_srefs.empty()) { + candidates.emplace_back(inst, std::move(location_indices)); + } + } + return decision; + }; + trace->ApplyToSchedule(sch, /*remove_postproc=*/true, f_provide_decision); + return candidates; +} + +Optional MutateComputeLocationNode::Apply(const Trace& trace, TRandState* rand_state) { + std::vector candidates = FindCandidates(trace, rand_state); + if (candidates.empty()) { + return NullOpt; + } + const Candidate& candidate = candidates[tir::SampleInt(rand_state, 0, candidates.size())]; + int loc = candidate.locs[tir::SampleInt(rand_state, 0, candidate.locs.size())]; + return trace->WithDecision(candidate.inst, Integer(loc), /*remove_postproc=*/true); +} + +Mutator Mutator::MutateComputeLocation() { + return Mutator(make_object()); +} + +TVM_REGISTER_NODE_TYPE(MutateComputeLocationNode); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateComputeLocation") + .set_body_typed(Mutator::MutateComputeLocation); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc new file mode 100644 index 000000000000..7c973879f2cc --- /dev/null +++ b/src/meta_schedule/mutator/mutate_parallel.cc @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Check if the instruction is annotation with `meta_schedule_parallel` + * \param inst The instruction to be checked + * \return Whether the instruction is annotation with `meta_schedule_parallel` + */ +bool IsAnnotateWithParallel(const Instruction& inst) { + static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate"); + if (!inst->kind.same_as(inst_annotate)) { + return false; + } + ICHECK_EQ(inst->attrs.size(), 1); + String ann_key = Downcast(inst->attrs[0]); + return ann_key == attr::meta_schedule_parallel; +} + +/*! + * \brief Replace the annotation value + * \param inst The instruction to be replaced + * \param ann_val The new annotation value + * \return The replaced instruction + */ +Instruction ReplaceAnnValue(Instruction inst, int64_t ann_val) { + ICHECK_EQ(inst->inputs.size(), 2); + return Instruction(/*kind=*/inst->kind, // + /*inputs=*/{inst->inputs[0], Integer(ann_val)}, // + /*attrs=*/inst->attrs, + /*outputs=*/inst->outputs); +} + +/*! + * \brief Get the output of the instruction Get-Block + * \param inst The instruction to be checked + * \return The output of the instruction Get-Block + */ +const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) { + static const InstructionKind& inst_get_block = InstructionKind::Get("GetBlock"); + if (!inst->kind.same_as(inst_get_block)) { + return nullptr; + } + ICHECK_EQ(inst->outputs.size(), 1); + const BlockRVNode* block = TVM_TYPE_AS(block, inst->outputs[0], BlockRVNode); + return block; +} + +/*! + * \brief Analyze the parallel structure + * \param self The schedule state + * \param block_name The name of the root block + * \param func_name The name of the PrimFunc + * \param limit The uplimit of the parallelism + * \return The parallel structure + */ +std::vector> AnalyzeParallel(const ScheduleState& self, + const String& block_name, const String& func_name, + int64_t limit) { + Array block_srefs = tir::GetBlocks(self, block_name, func_name); + ICHECK_EQ(block_srefs.size(), 1); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]); + ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef(block)); + std::vector> results; + results.reserve(info.realizes.size()); + for (const BlockRealize& realize : info.realizes) { + // Step 1. Extract static loop extents for spatial loops + std::vector loop_extents; + const ForNode* loop = nullptr; + for (const StmtSRefNode* loop_sref = self->stmt2ref.at(realize->block.get())->parent; + (loop = loop_sref->StmtAs()) != nullptr; // + loop_sref = loop_sref->parent) { + int64_t loop_extent = -1; + if (const auto* ext = GetLoopIntExtent(loop)) { + if (!info.non_spatial_vars.count(loop->loop_var.get())) { + loop_extent = *ext; + } + } + if (loop_extent != -1) { + loop_extents.push_back(loop_extent); + } else { + loop_extents.clear(); + } + } + // Step 2. Take the prefix product of loop extents + if (!loop_extents.empty()) { + results.emplace_back(); + std::vector& result = results.back(); + result.reserve(loop_extents.size()); + int64_t prod_extent = 1; + for (auto it = loop_extents.rbegin(); it != loop_extents.rend(); ++it) { + result.push_back(prod_extent *= *it); + if (prod_extent >= limit) { + break; + } + } + } + } + return results; +} + +/*! + * \brief Get the number of parallelizable loops for each subtree + * \param loop_extent_prods The parallel structure for each subtree + * \param limit The uplimit of the parallelism + * \return The number of parallelizable loops for each subtree + */ +std::vector GetNumFusedLoops(const std::vector>& loop_extent_prods, + int64_t limit) { + std::vector results; + results.reserve(loop_extent_prods.size()); + for (const std::vector& prods : loop_extent_prods) { + int n = prods.size(); + int i = std::upper_bound(prods.begin(), prods.end(), limit) - prods.begin(); + if (i > 0 && prods[i - 1] == limit) { + --i; + } + if (i != n) { + ++i; + } + results.push_back(i); + } + return results; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::Instruction; +using tir::Trace; + +/*! \brief Create a Mutator that mutates the parallel extent */ +class MutateParallelNode : public MutatorNode { + public: + /*! + * \brief The maximum number of jobs to be launched per CPU core. + * It sets the uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. + * Use -1 to disable parallelism. + */ + int64_t max_jobs_per_core; + /*! \brief The number of cores in CPU. */ + int max_parallel_extent_; + /*! \brief JSON representation of the workload */ + std::string json_mod_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("max_jobs_per_core", &max_jobs_per_core); + // `max_parallel_extent_` is not visited. + // `json_mod` is not visited. + } + + static constexpr const char* _type_key = "meta_schedule.MutateParallel"; + TVM_DECLARE_FINAL_OBJECT_INFO(MutateParallelNode, MutatorNode); + + public: + struct Candidate; + // Inherit from `MutatorNode` + void InitializeWithTuneContext(const TuneContext& context) final { + Target target = context->target.value(); + this->max_parallel_extent_ = GetTargetNumCores(target) * this->max_jobs_per_core; + this->json_mod_ = SaveJSON(context->mod.value()); + } + // Inherit from `MutatorNode` + Optional Apply(const Trace& trace, TRandState* rand_state) final; +}; + +/*! \brief The candidate to be mutated */ +struct MutateParallelNode::Candidate { + /*! \brief The annotation instruction */ + Instruction inst; + /*! \brief The current parallel extent */ + int64_t parallel_extent; + /*! \brief The name of the root block */ + String block_name; + /*! \brief The name of the PrimFunc */ + String func_name; +}; + +/*! + * \brief Get an instruction that annotates the maximum parallel extent + * \param trace The trace to be mutated + * \param rand_state The random state + * \param candidate The candidate to be mutated + * \return Whether a decision is found + */ +bool FindParallelDecision(const Trace& trace, TRandState* rand_state, + MutateParallelNode::Candidate* candidate) { + using tir::BlockRVNode; + using tir::InstructionNode; + std::unordered_map get_block_insts; + std::vector ann_insts; + get_block_insts.reserve(trace->insts.size()); + ann_insts.reserve(trace->insts.size()); + for (const Instruction& inst : trace->insts) { + if (tir::IsAnnotateWithParallel(inst)) { + ann_insts.push_back(inst.get()); + } + if (const BlockRVNode* block_rv = tir::GetInstGetBlockOutput(inst)) { + get_block_insts[block_rv] = inst.get(); + } + } + int n_ann_insts = ann_insts.size(); + if (n_ann_insts == 0) { + return false; + } + const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)]; + ICHECK_EQ(ann_inst->inputs.size(), 2); + const InstructionNode* get_block_inst = + get_block_insts.at(Downcast(ann_inst->inputs[0]).get()); + ICHECK_EQ(get_block_inst->attrs.size(), 2); + candidate->inst = GetRef(ann_inst); + candidate->parallel_extent = Downcast(ann_inst->inputs[1])->value; + candidate->block_name = Downcast(get_block_inst->attrs[0]); + candidate->func_name = Downcast(get_block_inst->attrs[1]); + return true; +} + +Optional MutateParallelNode::Apply(const Trace& trace, TRandState* rand_state) { + // Step 1. Find a parallel decision. + Candidate candidate; + if (!FindParallelDecision(trace, rand_state, &candidate)) { + return NullOpt; + } + // Step 2. Replay the instructions to recover loop extents + tir::Schedule sch = tir::Schedule::Traced( // + /*mod=*/Downcast(LoadJSON(this->json_mod_)), // + /*rand_state=*/ForkSeed(rand_state), // + /*debug_mode=*/0, + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); + trace->ApplyToSchedule(sch, /*remove_postproc=*/true); + // Step 3. Find all possible parallel plans + std::vector> loop_extent_prods = tir::AnalyzeParallel( + sch->state(), candidate.block_name, candidate.func_name, this->max_parallel_extent_); + std::unordered_map> limit2plan; + std::map, int64_t> plan2limit; + for (const std::vector& prods : loop_extent_prods) { + for (int64_t limit : prods) { + if (limit <= this->max_parallel_extent_ && !limit2plan.count(limit)) { + std::vector plan = tir::GetNumFusedLoops(loop_extent_prods, limit); + limit2plan[limit] = plan; + plan2limit[plan] = limit; + } + } + } + // Step 4. Remove the original plan and remove it + std::vector original_plan = + tir::GetNumFusedLoops(loop_extent_prods, candidate.parallel_extent); + auto it = plan2limit.find(original_plan); + if (it != plan2limit.end()) { + plan2limit.erase(it); + } + // Step 5. Pick a new plan + int n_plans = plan2limit.size(); + if (n_plans == 0) { + return NullOpt; + } + it = plan2limit.begin(); + for (int i = 0, n = tir::SampleInt(rand_state, 0, n_plans); i < n; ++i) { + ++it; + } + int64_t limit = it->second; + // Step 6. Assemble a new trace + Array insts; + insts.reserve(trace->insts.size()); + for (const Instruction& inst : trace->insts) { + if (inst.same_as(candidate.inst)) { + insts.push_back(tir::ReplaceAnnValue(candidate.inst, limit)); + } else if (inst->kind->IsPostproc()) { + break; + } else { + insts.push_back(inst); + } + } + return Trace(insts, trace->decisions); +} + +Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) { + ObjectPtr n = make_object(); + n->max_jobs_per_core = max_jobs_per_core; + return Mutator(n); +} + +TVM_REGISTER_NODE_TYPE(MutateParallelNode); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel").set_body_typed(Mutator::MutateParallel); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc new file mode 100644 index 000000000000..02c418b3c2c4 --- /dev/null +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +using tir::Instruction; +using tir::InstructionKind; +using tir::Trace; + +/*! + * \brief Downcast the decision of Sample-Perfect-Tile to an array of integers + * \param decision The decision of Sample-Perfect-Tile + * \return The result of downcast + */ +std::vector DowncastTilingDecision(const ObjectRef& decision) { + const auto* arr = TVM_TYPE_AS(arr, decision, runtime::ArrayNode); + return support::AsVector(GetRef>(arr)); +} + +/*! + * \brief Calculate the product of elements in an array + * \param array The array + * \return The product of elements in the array + */ +int64_t Product(const std::vector& array) { + int64_t result = 1; + for (int64_t x : array) { + result *= x; + } + return result; +} + +/*! \brief A mutator that mutates the tile size */ +class MutateTileSizeNode : public MutatorNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + static constexpr const char* _type_key = "meta_schedule.MutateTileSize"; + TVM_DECLARE_FINAL_OBJECT_INFO(MutateTileSizeNode, MutatorNode); + + public: + // Inherit from `MutatorNode` + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherit from `MutatorNode` + Optional Apply(const Trace& trace, TRandState* rand_state) final; +}; + +/*! + * \brief Find a sample-perfect-tile decision in the trace + * \param trace The trace + * \param rand_state The random state + * \param inst The instruction selected + * \param decision The decision selected + * \return Whether a decision is found + */ +void FindSamplePerfectTile(const Trace& trace, std::vector* inst, + std::vector>* decision) { + static const InstructionKind& inst_sample_perfect_tile = + InstructionKind::Get("SamplePerfectTile"); + std::vector& instructions = *inst; + std::vector>& decisions = *decision; + instructions.reserve(trace->decisions.size()); + decisions.reserve(trace->decisions.size()); + for (const auto& kv : trace->decisions) { + const Instruction& inst = kv.first; + const ObjectRef& decision = kv.second; + if (inst->kind.same_as(inst_sample_perfect_tile)) { + std::vector tiles = DowncastTilingDecision(decision); + if (tiles.size() >= 2 && Product(tiles) >= 2) { + instructions.push_back(inst); + decisions.push_back(tiles); + } + } + } +} + +void FindSampleVectorize(const Trace& trace, std::vector* inst, + std::vector* decision) { + static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical"); + static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate"); + std::vector& instructions = *inst; + std::vector& decisions = *decision; + std::unordered_set annotated; + instructions.reserve(trace->decisions.size()); + decisions.reserve(trace->decisions.size()); + annotated.reserve(trace->decisions.size()); + // Find annotation with `meta_schedule_cooperative_fetch` + for (const Instruction& inst : trace->insts) { + if (inst->kind.same_as(inst_annotate)) { + ICHECK_EQ(inst->attrs.size(), 1); + ICHECK_EQ(inst->inputs.size(), 2); + if (Downcast(inst->attrs[0]) == tir::attr::meta_schedule_cooperative_fetch) { + const auto* ann_val = inst->inputs[1].as(); + ICHECK(ann_val); + annotated.insert(ann_val); + } + } + } + // Find sampling instruction that generates the annotation + for (const auto& kv : trace->decisions) { + const Instruction& inst = kv.first; + const ObjectRef& decision = kv.second; + if (inst->kind.same_as(inst_sample_categorical)) { + ICHECK_EQ(inst->outputs.size(), 1); + if (annotated.count(inst->outputs[0].get())) { + const auto* d = TVM_TYPE_AS(d, decision, IntImmNode); + instructions.push_back(inst); + decisions.push_back(d->value); + } + } + } +} + +struct FactorMemo { + static std::vector Factorize(int n) { + if (const std::vector* result = Global()->Query(n)) { + return *result; + } + std::vector result; + for (int64_t i = 1; i * i < n; ++i) { + if (n % i == 0) { + result.push_back(i); + if (i * i != n) { + result.push_back(n / i); + } + } + } + std::sort(result.begin(), result.end()); + Global()->Add(n, result); + return result; + } + + private: + const std::vector* Query(int n) { + std::unique_lock lock(mutex_); + auto it = memo_.find(n); + if (it != memo_.end()) { + return &it->second; + } + return nullptr; + } + + void Add(int n, std::vector result) { + std::unique_lock lock(mutex_); + memo_.emplace(n, std::move(result)); + } + + static FactorMemo* Global() { + static FactorMemo singleton; + return &singleton; + } + + std::unordered_map> memo_; + std::mutex mutex_; +}; + +Optional MutateSampleTileSize(const Trace& trace, Instruction inst, + std::vector tiles, TRandState* rand_state) { + int n_splits = tiles.size(); + // Step 1. Choose two loops, `x` and `y` + int x, y; + // select source + while (true) { + x = tir::SampleInt(rand_state, 0, n_splits); + if (tiles[x] <= 1) { + continue; + } + y = tir::SampleInt(rand_state, 0, n_splits - 1); + if (y >= x) { + ++y; + } + std::vector factors = FactorMemo::Factorize(tiles[x]); + // Step 2. Choose the divide factor + int64_t divide_factor; + if (y != n_splits - 1) { + divide_factor = factors[tir::SampleInt(rand_state, 1, factors.size())]; + } else { + int64_t limit = Downcast(inst->attrs[1])->value; + int max_factor_index = static_cast(factors.size()) - 1; + for (; max_factor_index >= 1; max_factor_index--) { + if (factors[max_factor_index] * tiles[y] <= limit) { + break; + } + } + if (max_factor_index == 0) { + if (n_splits <= 2) { + return NullOpt; + } + // Failed on this dst_idx, try next one. + continue; + } + divide_factor = factors[tir::SampleInt(rand_state, 1, max_factor_index + 1)]; + } + tiles[x] /= divide_factor; + tiles[y] *= divide_factor; + return trace->WithDecision(inst, support::AsArray(tiles), + /*remove_postproc=*/true); + } +} + +Optional MutateSampleVectorize(const Trace& trace, Instruction inst, + int64_t original_decision, TRandState* rand_state) { + ICHECK_EQ(inst->attrs.size(), 2); + std::vector probs = + support::AsVector(Downcast>(inst->attrs[1])); + probs.erase(probs.begin() + original_decision); + int result = tir::MakeMultinomialSampler(rand_state, probs)(); + if (result >= original_decision) { + result += 1; + } + return trace->WithDecision(inst, Integer(result), /*remove_postproc=*/true); +} + +Optional MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) { + std::vector sample_perfect_tile_insts; + std::vector sample_vectorize_insts; + std::vector> sample_perfect_tile_tiles; + std::vector sample_vectorize_decisions; + FindSamplePerfectTile(trace, &sample_perfect_tile_insts, &sample_perfect_tile_tiles); + FindSampleVectorize(trace, &sample_vectorize_insts, &sample_vectorize_decisions); + int size_a = sample_perfect_tile_insts.size(); + int size_b = sample_vectorize_insts.size(); + if (size_a == 0 && size_b == 0) { + return NullOpt; + } + int n = tir::SampleInt(rand_state, 0, size_a + size_b); + if (n < size_a) { + return MutateSampleTileSize(trace, sample_perfect_tile_insts[n], sample_perfect_tile_tiles[n], + rand_state); + } else { + n -= size_a; + return MutateSampleVectorize(trace, sample_vectorize_insts[n], sample_vectorize_decisions[n], + rand_state); + } +} + +Mutator Mutator::MutateTileSize() { return Mutator(make_object()); } + +TVM_REGISTER_NODE_TYPE(MutateTileSizeNode); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateTileSize").set_body_typed(Mutator::MutateTileSize); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc new file mode 100644 index 000000000000..94e83488584e --- /dev/null +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Check if an instruction is annotate with + * `meta_schedule_unroll_explicit` or `meta_schedule_unroll_implicit` + * \param inst The instruction to be checked + * \return Whether the instruction is annotated + */ +bool IsAnnotateWithUnroll(const Instruction& inst) { + static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate"); + if (!inst->kind.same_as(inst_annotate)) { + return false; + } + ICHECK_EQ(inst->attrs.size(), 1); + String ann_key = Downcast(inst->attrs[0]); + return ann_key == attr::meta_schedule_unroll_explicit || + ann_key == attr::meta_schedule_unroll_implicit; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::Instruction; +using tir::Trace; + +/*! \brief Create a Mutator that mutates auto unroll step */ +class MutateUnrollNode : public MutatorNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + static constexpr const char* _type_key = "meta_schedule.MutateUnroll"; + TVM_DECLARE_FINAL_OBJECT_INFO(MutateUnrollNode, MutatorNode); + + public: + struct Candidate; + // Inherit from `MutatorNode` + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherit from `MutatorNode` + Optional Apply(const Trace& trace, TRandState* rand_state) final; +}; + +/*! \brief A candidate to be mutated */ +struct MutateUnrollNode::Candidate { + /*! \brief The sampling instruction to be mutated */ + Instruction inst; + /*! \brief The probability */ + std::vector probs; + /*! \brief The decision made */ + int decision; +}; + +/*! + * \brief Find the Sample-Categorical instruction to be mutated that affects the maximal unroll step + * \param trace The trace to be mutated + * \param rand_state The random state + * \param candidates The mutation candidate + * \return Whether a decision is found + */ +bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, + MutateUnrollNode::Candidate* candidate) { + using tir::InstructionKind; + using tir::InstructionNode; + static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical"); + std::unordered_map sample_insts; + std::vector ann_insts; + sample_insts.reserve(trace->insts.size()); + ann_insts.reserve(trace->insts.size()); + for (const Instruction& inst : trace->insts) { + if (inst->kind.same_as(inst_sample_categorical)) { + ICHECK_EQ(inst->outputs.size(), 1); + const PrimExprNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[0], PrimExprNode); + sample_insts[var_rv] = inst.get(); + } else if (IsAnnotateWithUnroll(inst)) { + ann_insts.push_back(inst.get()); + } + } + int n_ann_insts = ann_insts.size(); + if (n_ann_insts == 0) { + return false; + } + const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)]; + ICHECK_EQ(ann_inst->inputs.size(), 2); + const auto* var_rv = TVM_TYPE_AS(var_rv, ann_inst->inputs[1], PrimExprNode); + ICHECK(sample_insts.count(var_rv)); + const InstructionNode* sample_inst = sample_insts.at(var_rv); + ICHECK_EQ(sample_inst->attrs.size(), 2); + candidate->inst = GetRef(sample_inst); + candidate->decision = + Downcast(trace->decisions[GetRef(sample_inst)])->value; + candidate->probs = + support::AsVector(Downcast>(sample_inst->attrs[1])); + return true; +} + +Optional MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_state) { + Candidate candidate; + if (!FindUnrollDecision(trace, rand_state, &candidate)) { + return NullOpt; + } + if (candidate.probs.size() == 0) { + return NullOpt; + } + candidate.probs.erase(candidate.probs.begin() + candidate.decision); + int result = tir::MakeMultinomialSampler(rand_state, candidate.probs)(); + if (result >= candidate.decision) { + result += 1; + } + return trace->WithDecision(candidate.inst, Integer(result), /*remove_postproc=*/true); +} + +Mutator Mutator::MutateUnroll() { return Mutator(make_object()); } + +TVM_REGISTER_NODE_TYPE(MutateUnrollNode); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorMutateUnroll").set_body_typed(Mutator::MutateUnroll); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc new file mode 100644 index 000000000000..27383adf84e0 --- /dev/null +++ b/src/meta_schedule/mutator/mutator.cc @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +Mutator Mutator::PyMutator( + PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyMutatorNode::FApply f_apply, // + PyMutatorNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return Mutator(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyMutatorNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyMutator's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(MutatorNode); +TVM_REGISTER_NODE_TYPE(PyMutatorNode); + +TVM_REGISTER_GLOBAL("meta_schedule.MutatorInitializeWithTuneContext") + .set_body_method(&MutatorNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorApply") + .set_body_typed([](Mutator self, tir::Trace trace, TRandState seed) -> Optional { + TRandState seed_ = (seed != -1) ? seed : support::LinearCongruentialEngine::DeviceRandom(); + return self->Apply(trace, &seed_); + }); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/disallow_dynamic_loop.cc b/src/meta_schedule/postproc/disallow_dynamic_loop.cc new file mode 100644 index 000000000000..715815843a84 --- /dev/null +++ b/src/meta_schedule/postproc/disallow_dynamic_loop.cc @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! \brief Check if the loop is dynamic. */ +struct DynamicExtentFinder : private StmtVisitor { + public: + static bool Find(const IRModule& mod) { + DynamicExtentFinder finder; + for (const auto& kv : mod->functions) { + const BaseFunc& func = kv.second; + if (const auto* prim_func = func.as()) { + finder(prim_func->body); + if (finder.found_) { + return true; + } + } + } + return false; + } + + private: + void VisitStmt_(const ForNode* loop) final { + if (!loop->extent->IsInstance()) { + found_ = true; + } else { + StmtVisitor::VisitStmt_(loop); + } + } + + void VisitStmt(const Stmt& stmt) final { + if (!found_) { + StmtVisitor::VisitStmt(stmt); + } + } + + bool found_ = false; +}; + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::Schedule; + +/*! \brief Check if the IRModule has any loop with non-constant extent. */ +class DisallowDynamicLoopNode : public PostprocNode { + public: + // Inherited from PostprocNode + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherited from PostprocNode + bool Apply(const tir::Schedule& sch) final { return !tir::DynamicExtentFinder::Find(sch->mod()); } + + static constexpr const char* _type_key = "meta_schedule.DisallowDynamicLoop"; + TVM_DECLARE_FINAL_OBJECT_INFO(DisallowDynamicLoopNode, PostprocNode); +}; + +Postproc Postproc::DisallowDynamicLoop() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(DisallowDynamicLoopNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocDisallowDynamicLoop") + .set_body_typed(Postproc::DisallowDynamicLoop); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc new file mode 100644 index 000000000000..ff069e2c68cb --- /dev/null +++ b/src/meta_schedule/postproc/postproc.cc @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +Postproc Postproc::PyPostproc( + PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyPostprocNode::FApply f_apply, // + PyPostprocNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return Postproc(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyPostprocNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyPostproc's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(PostprocNode); +TVM_REGISTER_NODE_TYPE(PyPostprocNode); + +TVM_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext") + .set_body_method(&PostprocNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method(&PostprocNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc new file mode 100644 index 000000000000..279396d03c25 --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Parse instruction: sch.bind(..., axis) + * \param sch The schedule + * \param inst The instruction to be parsed + * \param axis The axis name expected + * \return NullOpt if parsing fails; Otherwise, the extent of thread axis + */ +Optional ParseThreadBinding(const Schedule& sch, const Instruction& inst, String axis) { + static InstructionKind inst_kind_bind = InstructionKind::Get("Bind"); + if (!inst->kind.same_as(inst_kind_bind)) { + return NullOpt; + } + ICHECK_EQ(inst->inputs.size(), 1); + ICHECK_EQ(inst->attrs.size(), 1); + String thread_axis = Downcast(inst->attrs[0]); + if (thread_axis != axis) { + return NullOpt; + } + return Downcast(sch->Get(Downcast(inst->inputs[0]))->extent); +} + +/*! + * \brief Parse instruction: sch.annotate(..., attr::meta_schedule_cooperative_fetch) + * \param sch The schedule + * \param inst The instruction to be parsed + * \param vector_lane The length of vector lane in vectorized cooperative fetching + * \return NullOpt if parsing fails; Otherwise, the annotated block + */ +Optional ParseAnnotate(const Schedule& sch, const Instruction& inst, int* vector_lane) { + static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate"); + if (!inst->kind.same_as(inst_kind_annotate)) { + return NullOpt; + } + ICHECK_EQ(inst->inputs.size(), 2); + ICHECK_EQ(inst->attrs.size(), 1); + String ann_key = Downcast(inst->attrs[0]); + if (ann_key != attr::meta_schedule_cooperative_fetch) { + return NullOpt; + } + *vector_lane = Downcast(sch->Get(Downcast(inst->inputs[1])))->value; + return Downcast(inst->inputs[0]); +} + +/*! + * \brief Parse instruction: sch.annotate(..., attr::meta_schedule_tensor_core_enabled) + * \param sch The schedule + * \param inst The instruction to be parsed + * \return Whether ths parsing is successful + */ +bool ParseTensorCoreAnn(const Schedule& sch, const Instruction& inst) { + static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate"); + if (!inst->kind.same_as(inst_kind_annotate)) { + return false; + } + ICHECK_EQ(inst->inputs.size(), 2); + ICHECK_EQ(inst->attrs.size(), 1); + String ann_key = Downcast(inst->attrs[0]); + if (ann_key != attr::meta_schedule_tensor_core_enabled) { + return false; + } + return true; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +/*! + * \brief Rewrite the cooperative fetch annotation to actual vectorized cooperative fetching + * in loop bindings. + */ +class RewriteCooperativeFetchNode : public PostprocNode { + public: + // Inherited from PostprocNode + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherited from PostprocNode + bool Apply(const tir::Schedule& sch) final; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "meta_schedule.RewriteCooperativeFetch"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteCooperativeFetchNode, PostprocNode); +}; + +bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) { + using tir::BlockRV; + using tir::Instruction; + using tir::LoopRV; + using tir::Schedule; + using tir::Trace; + Trace trace = sch->trace().value(); + int thread_extent_x = -1; + int thread_extent_y = -1; + int vector_lane = -1; + std::vector> tasks; + for (const Instruction& inst : trace->insts) { + if (Optional new_thread_extent = tir::ParseThreadBinding(sch, inst, "threadIdx.x")) { + thread_extent_x = new_thread_extent.value()->value; + } else if (Optional new_thread_extent = + tir::ParseThreadBinding(sch, inst, "threadIdx.y")) { + thread_extent_y = new_thread_extent.value()->value; + } else if (tir::ParseTensorCoreAnn(sch, inst)) { + thread_extent_x = 32; + } else if (Optional block_rv = tir::ParseAnnotate(sch, inst, &vector_lane)) { + ICHECK_NE(thread_extent_x, -1); + if (vector_lane > 1) { + tasks.push_back([thread_extent_x, thread_extent_y, vector_lane, sch, + block = block_rv.value()]() -> void { + LoopRV fused = sch->GetLoops(block).back(); + if (thread_extent_y == -1) { + Array split = sch->Split(fused, {NullOpt, // + Integer(thread_extent_x), // + Integer(vector_lane)}); + sch->Vectorize(split[2]); + sch->Bind(split[1], "threadIdx.x"); + } else { + Array split = sch->Split(fused, {NullOpt, // + Integer(thread_extent_y), // + Integer(thread_extent_x), // + Integer(vector_lane)}); + sch->Vectorize(split[3]); + sch->Bind(split[2], "threadIdx.x"); + sch->Bind(split[1], "threadIdx.y"); + sch->StorageAlign(block, 0, -2, 32, 8); + } + }); + } else { + tasks.push_back( + [thread_extent_x, thread_extent_y, sch, block = block_rv.value()]() -> void { + LoopRV fused = sch->GetLoops(block).back(); + if (thread_extent_y == -1) { + Array split = sch->Split(fused, {NullOpt, Integer(thread_extent_x)}); + sch->Bind(split[1], "threadIdx.x"); + } else { + Array split = sch->Split(fused, {NullOpt, // + Integer(thread_extent_y), // + Integer(thread_extent_x)}); + sch->Bind(split[2], "threadIdx.x"); + sch->Bind(split[1], "threadIdx.y"); + sch->StorageAlign(block, 0, -2, 32, 8); + } + }); + } + } + } + for (auto&& task : tasks) { + task(); + } + return true; +} + +Postproc Postproc::RewriteCooperativeFetch() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteCooperativeFetchNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteCooperativeFetch") + .set_body_typed(Postproc::RewriteCooperativeFetch); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc new file mode 100644 index 000000000000..4eca068e17c4 --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Check whether the block/loop has any annotation + * \param sref The sref of block/loop + * \return Whether the block/loop has any annotation + */ +inline bool HasAnnOrBinding(const ForNode* loop) { + return loop->kind == ForKind::kThreadBinding || !loop->annotations.empty(); +} + +class StrideExtractor : public StmtExprVisitor { + public: + static int64_t Extract(const PrimExpr& expr, const Var& var) { + StrideExtractor extractor(var); + extractor.VisitExpr(expr); + return extractor.strides_[expr.get()]; + } + + private: + explicit StrideExtractor(const Var& var) : var_(var) {} + + void VisitExpr_(const MulNode* node) final { + StmtExprVisitor::VisitExpr_(node); + + if (const auto* a = node->a.as()) { + if (strides_.count(node->b.get())) { + strides_[node] = strides_[node->b.get()] * a->value; + } + } else if (const auto* b = node->b.as()) { + if (strides_.count(node->a.get())) { + strides_[node] = strides_[node->a.get()] * b->value; + } + } + } + + void VisitExpr_(const AddNode* node) final { + StmtExprVisitor::VisitExpr_(node); + int64_t stride_a, stride_b; + if (strides_.count(node->a.get())) { + stride_a = strides_[node->a.get()]; + } else { + stride_a = INT64_MAX; + } + if (strides_.count(node->b.get())) { + stride_b = strides_[node->b.get()]; + } else { + stride_b = INT64_MAX; + } + if (stride_a != INT64_MAX || stride_b != INT64_MAX) { + strides_[node] = std::min(stride_a, stride_b); + } + } + + void VisitExpr_(const VarNode* node) final { + if (node == var_.get()) { + strides_[node] = 1; + } + } + + const Var& var_; + std::unordered_map strides_; +}; + +struct ParsedAnnotation { + int max_parallel_extent; + int max_vectorize_extent; + int unroll_explicit; + int unroll_implicit; + int num_parallel_loops; + int num_vectorize_loops; +}; + +bool ParseAnnotation(const Block& block, ParsedAnnotation* parsed) { + bool found = false; + *parsed = ParsedAnnotation{-1, -1, -1, -1, -1, -1}; + for (const auto& ann : block->annotations) { + if (ann.first == attr::meta_schedule_parallel) { + found = true; + if (const auto* imm = ann.second.as()) { + parsed->max_parallel_extent = imm->value; + } + } else if (ann.first == attr::meta_schedule_vectorize) { + found = true; + if (const auto* imm = ann.second.as()) { + parsed->max_vectorize_extent = imm->value; + } + } else if (ann.first == attr::meta_schedule_unroll_explicit) { + found = true; + if (const auto* imm = ann.second.as()) { + parsed->unroll_explicit = imm->value; + } + } else if (ann.first == attr::meta_schedule_unroll_implicit) { + found = true; + if (const auto* imm = ann.second.as()) { + parsed->unroll_implicit = imm->value; + } + } + } + return found; +} + +void RemoveParsedAnn(const Schedule& sch, const BlockRV& block_rv, const ParsedAnnotation& parsed) { + if (parsed.max_parallel_extent != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_parallel); + } + if (parsed.max_vectorize_extent != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_vectorize); + } + if (parsed.unroll_explicit != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_unroll_explicit); + } + if (parsed.unroll_implicit != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_unroll_implicit); + } +} + +void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, + const Array& loop_rvs, ParsedAnnotation* parsed) { + StmtSRef block_sref = sch->GetSRef(block_rv); + if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) { + return; + } + int n_loops = loop_rvs.size(); + if (n_loops == 0) { + parsed->max_parallel_extent = -1; + parsed->max_vectorize_extent = -1; + return; + } + // Extract loop_srefs, and calculate the iterator types + Array loop_srefs; + std::vector loop_types; + { + loop_srefs.reserve(n_loops); + loop_types.reserve(n_loops); + for (const LoopRV& loop_rv : loop_rvs) { + loop_srefs.push_back(sch->GetSRef(loop_rv)); + loop_types.push_back(GetLoopIterType(loop_srefs.back())); + } + } + // check the maximal number of axes that are vectorizable (contiguous memory access) + BlockRealize realize = GetBlockRealize(sch->state(), block_sref); + Array buffer_access(realize->block->reads); + buffer_access.insert(buffer_access.end(), realize->block->writes.begin(), + realize->block->writes.end()); + std::unordered_map binding_map; + for (size_t i = 0; i < realize->iter_values.size(); i++) { + binding_map[realize->block->iter_vars[i]->var.get()] = realize->iter_values[i]; + } + int max_fusible = INT32_MAX; + // for each block read/write, get the strides of the loop vars and find the fusible + // (vectorizable) axes + for (const BufferRegion& access : buffer_access) { + int fusible = 0; + std::vector strides; + // get strides for each loop var + for (const StmtSRef& loop_sref : loop_srefs) { + int64_t stride = 0, buffer_stride = 1; + const auto* var = loop_sref->StmtAs(); + arith::Analyzer analyzer; + for (int i = access->region.size() - 1; i >= 0; i--) { + PrimExpr idx = analyzer.Simplify(Substitute(access->region[i]->min, binding_map)); + int64_t coef = StrideExtractor::Extract(idx, var->loop_var); + if (coef != 0) { + stride = coef * buffer_stride; + break; + } + buffer_stride *= access->buffer->shape[i].as()->value; + } + strides.push_back(stride); + } + int prev_used_iter = -1; + // check the number of fusible loops + for (int i = strides.size() - 1; i >= 0; i--) { + if (strides[i] == 0) { + // not used in the buffer access, safe to fuse + fusible++; + continue; + } else if (prev_used_iter == -1) { + // the stride of last axis is not 1 means the memory access is not contiguous + if (strides[i] != 1) { + break; + } + fusible++; + prev_used_iter = i; + } else { + // contiguous memory access + const auto* prev_loop = loop_srefs[prev_used_iter]->StmtAs(); + int64_t prev_used_iter_extent = prev_loop->extent.as()->value; + if (strides[i] == strides[prev_used_iter] * prev_used_iter_extent) { + fusible++; + prev_used_iter = i; + } else { + break; + } + } + } + max_fusible = std::min(max_fusible, fusible); + } + // Calculate the parallelize extent + if (parsed->max_parallel_extent != -1) { + int max_extent = parsed->max_parallel_extent; + int& num_fusible = parsed->num_parallel_loops = 0; + int64_t prod_extent = 1; + for (int i = 0; i < n_loops && loop_types[i] == IterVarType::kDataPar; ++i) { + const StmtSRef& loop_sref = loop_srefs[i]; + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + if (HasAnnOrBinding(loop)) { + break; + } + // Check if the loop extent is valid + const int64_t* extent = GetLoopIntExtent(loop_sref); + if (extent == nullptr) { + break; + } + // Then we can fuse it in + ++num_fusible; + // Check if we need to break + prod_extent *= *extent; + if (prod_extent > max_extent || !IsSingleStmt(loop->body)) { + break; + } + } + if (prod_extent == 1) { + num_fusible = -1; + } + } + // Calculate the vectorize extent + if (parsed->max_vectorize_extent != -1) { + int max_extent = parsed->max_vectorize_extent; + int& num_fusible = parsed->num_vectorize_loops = 0; + int64_t prod_extent = 1; + for (int i = n_loops - 1; + i >= 0 && loop_types[i] == IterVarType::kDataPar && num_fusible < max_fusible; --i) { + const StmtSRef& loop_sref = loop_srefs[i]; + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + if (HasAnnOrBinding(loop)) { + break; + } + // Cannot vectorize reduce axis + if (GetLoopIterType(loop_sref) != IterVarType::kDataPar) { + break; + } + // Cannot fuse with a loop with multiple children + if (!IsSingleStmt(loop->body)) { + break; + } + // Check if the loop extent is valid + const int64_t* extent = GetLoopIntExtent(loop_sref); + if (extent == nullptr) { + break; + } + // Check if the extent is still in a good range + prod_extent *= *extent; + if (prod_extent > max_extent) { + break; + } + ++num_fusible; + } + if (prod_extent == 1) { + num_fusible = -1; + } + } + // Prefer num_vectorize to num_parallel + if (parsed->num_parallel_loops != -1 && parsed->num_vectorize_loops != -1) { + parsed->num_parallel_loops = std::min(parsed->num_parallel_loops, // + n_loops - parsed->num_vectorize_loops); + } +} + +bool FindAnnotateRootBlock(const Schedule& sch, ParsedAnnotation* parsed, BlockRV* root_rv) { + IRModule mod = sch->mod(); + for (const auto& kv : mod->functions) { + const GlobalVar& g_var = kv.first; + const BaseFunc& base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + Block block = Downcast(prim_func->body)->block; + if (ParseAnnotation(block, parsed)) { + *root_rv = sch->GetBlock(block->name_hint, g_var->name_hint); + RemoveParsedAnn(sch, *root_rv, *parsed); + return true; + } + } + } + return false; +} + +void RewriteParallel(const Schedule& sch, int n, Array* loop_rvs) { + LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->begin() + n}); + sch->Parallel(fused); + for (int i = 0; i < n; ++i) { + loop_rvs->Set(i, fused); + } +} + +void RewriteVectorize(const Schedule& sch, int n, Array* loop_rvs) { + int n_loops = loop_rvs->size(); + LoopRV fused = sch->Fuse({loop_rvs->end() - n, loop_rvs->end()}); + sch->Vectorize(fused); + for (int i = n_loops - n; i < n_loops; ++i) { + loop_rvs->Set(i, fused); + } +} + +void RewriteUnroll(const Schedule& sch, int unroll_explicit, int max_step, const LoopRV& loop) { + if (max_step > 0) { + sch->Annotate(loop, attr::pragma_auto_unroll_max_step, IntImm(DataType::Int(32), max_step)); + sch->Annotate(loop, attr::pragma_unroll_explicit, IntImm(DataType::Int(32), unroll_explicit)); + } +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::Schedule; + +class RewriteParallelVectorizeUnrollNode : public PostprocNode { + public: + void InitializeWithTuneContext(const TuneContext& context) final {} + + bool Apply(const Schedule& sch) final { + using tir::BlockRV; + using tir::LoopRV; + tir::ParsedAnnotation parsed_root; + BlockRV root_rv{nullptr}; + while (tir::FindAnnotateRootBlock(sch, &parsed_root, &root_rv)) { + for (BlockRV block_rv : sch->GetChildBlocks(root_rv)) { + Array loop_rvs = sch->GetLoops(block_rv); + if (loop_rvs.empty()) { + continue; + } + tir::ParsedAnnotation parsed = parsed_root; + tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed); + // Parallel + if (parsed.num_parallel_loops > 0) { + tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs); + } + // Vectorize + if (parsed.num_vectorize_loops > 0) { + tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs); + } + // AutoUnroll + if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) { + ICHECK(parsed.unroll_explicit == -1 || parsed.unroll_implicit == -1); + int unroll_explicit = parsed.unroll_explicit != -1; + int max_step = parsed.unroll_explicit + parsed.unroll_implicit + 1; + tir::RewriteUnroll(sch, unroll_explicit, max_step, loop_rvs[0]); + } + } + } + return true; + } + + static constexpr const char* _type_key = "meta_schedule.RewriteParallelVectorizeUnroll"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteParallelVectorizeUnrollNode, PostprocNode); +}; + +Postproc Postproc::RewriteParallelVectorizeUnroll() { + ObjectPtr n = + make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteParallelVectorizeUnrollNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteParallelVectorizeUnroll") + .set_body_typed(Postproc::RewriteParallelVectorizeUnroll); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc new file mode 100644 index 000000000000..9578d3e6261b --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! \brief The visitor that finds all the reduction block to be decomposed */ +struct ReductionBlockFinder : private StmtVisitor { + public: + /*! \brief Find all the reduction blocks that should be decomposed */ + static std::vector> Find(const ScheduleState& self) { + std::vector> results; + for (const auto& kv : self->mod->functions) { + GlobalVar g_var = kv.first; + BaseFunc base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + ReductionBlockFinder finder; + finder(prim_func->body); + for (const BlockNode* block : finder.results_) { + results.emplace_back(self->stmt2ref.at(block), g_var->name_hint); + } + } + } + return results; + } + + private: + void VisitStmt_(const ForNode* loop) final { + runtime::ThreadScope thread_scope = GetThreadScope(loop); + if (IsThreadIdx(thread_scope) || IsBlockIdx(thread_scope)) { + thread_bound_loop_vars_.insert(loop->loop_var.get()); + } + StmtVisitor::VisitStmt_(loop); + } + + void VisitStmt_(const BlockRealizeNode* realize) final { + if (realize->block->init.defined() && AllReductionIterVarAreUnbound(realize)) { + results_.push_back(realize->block.get()); + } + StmtVisitor::VisitStmt_(realize); + } + + bool AllReductionIterVarAreUnbound(const BlockRealizeNode* realize) const { + if (thread_bound_loop_vars_.empty()) { + return true; + } + auto f_find = [this](const VarNode* var) -> bool { return thread_bound_loop_vars_.count(var); }; + const BlockNode* block = realize->block.get(); + int n = block->iter_vars.size(); + for (int i = 0; i < n; ++i) { + IterVar iter_var = block->iter_vars[i]; + PrimExpr binding = realize->iter_values[i]; + if (iter_var->iter_type == tir::kCommReduce) { + if (UsesVar(binding, f_find)) { + return false; + } + } + } + return true; + } + + /*! \brief The results of the collection */ + std::vector results_; + /*! \brief Loop variables that are bound to threads */ + std::unordered_set thread_bound_loop_vars_; +}; + +/*! + * \brief Find the innermost loop that could be decomposed to + * \param block_sref The block to be decomposed + * \return The index of the innermost loop that could be decomposed + */ +int FindDecomposePoint(const StmtSRef& block_sref) { + Array loop_srefs = GetLoops(block_sref); + int n = loop_srefs.size(); + for (int i = 0; i < n; ++i) { + if (GetLoopIterType(loop_srefs[i]) != IterVarType::kDataPar) { + return i; + } + } + return -1; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +/*! \brief Rewrite reduction block by moving the init block out */ +class RewriteReductionBlockNode : public PostprocNode { + public: + // Inherited from PostprocNode + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherited from PostprocNode + bool Apply(const tir::Schedule& sch) final; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "meta_schedule.RewriteReductionBlock"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteReductionBlockNode, PostprocNode); +}; + +bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) { + for (;;) { + std::vector> results = + tir::ReductionBlockFinder::Find(sch->state()); + int rewritten = 0; + for (const auto& kv : results) { + const tir::StmtSRef& block_sref = kv.first; + const String& global_var_name = kv.second; + int decompose_point = tir::FindDecomposePoint(block_sref); + if (decompose_point == -1) { + continue; + } + tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); + Array loop_rvs = sch->GetLoops(block_rv); + tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]); + // If the block is the isolation block of tensor core, + // we mark the init block for later postprocessor to handle the tensorization step + if (HasAnn(block_sref, tir::attr::meta_schedule_auto_tensorize, "wmma_fill")) { + sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize); + sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize); + Array init_inner_block_rv = sch->GetChildBlocks(init_block_rv); + ICHECK_EQ(init_inner_block_rv.size(), 1); + sch->Annotate(init_inner_block_rv[0], tir::attr::meta_schedule_auto_tensorize, + String("wmma_fill")); + } + ++rewritten; + } + if (rewritten == 0) { + break; + } + } + return true; +} + +Postproc Postproc::RewriteReductionBlock() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteReductionBlockNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteReductionBlock") + .set_body_typed(Postproc::RewriteReductionBlock); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_tensor_core.cc b/src/meta_schedule/postproc/rewrite_tensor_core.cc new file mode 100644 index 000000000000..68442dec3082 --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_tensor_core.cc @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +using tir::BlockRV; +using tir::LoopRV; + +using BlockPosition = std::tuple; + +class RewriteTensorCoreNode : public PostprocNode { + public: + // Inherited from PostprocNode + void InitializeWithTuneContext(const TuneContext& context) final {} + + // Inherited from PostprocNode + bool Apply(const tir::Schedule& sch) final; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "meta_schedule.RewriteTensorCore"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteTensorCoreNode, PostprocNode); +}; + +void CollectTensorized(const tir::Schedule& sch, const String& func_name, + const tir::PrimFuncNode* func, std::vector& tasks) { + tir::PreOrderVisit( + func->body, + [&](const ObjectRef& obj) -> bool { + if (const auto* block = obj.as()) { + tir::StmtSRef block_sref = sch->GetSRef(block); + if (Optional intrin_name = + tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize)) { + tasks.push_back(std::make_tuple(block_sref->StmtAs()->name_hint, + func_name, intrin_name.value())); + } + } + return true; + }, + /*visit_init_block=*/false); +} + +bool RewriteTensorCoreNode::Apply(const tir::Schedule& sch) { + std::vector tasks; + for (const auto& kv : sch->mod()->functions) { + GlobalVar g_var = kv.first; + BaseFunc base_func = kv.second; + if (const tir::PrimFuncNode* prim_func = base_func.as()) { + CollectTensorized(sch, g_var->name_hint, prim_func, tasks); + } + } + for (const BlockPosition& task : tasks) { + // Retrieve the block rv according to the task noted down before + BlockRV block_rv = sch->GetBlock(std::get<0>(task), std::get<1>(task)); + String intrin_name = std::get<2>(task); + sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize); + Optional tiled_loop_rv = TilingwithTensorIntrin(sch, block_rv, intrin_name); + if (!tiled_loop_rv.defined()) continue; + sch->Tensorize(tiled_loop_rv.value(), intrin_name); + } + return true; +} + +Postproc Postproc::RewriteTensorCore() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteTensorCoreNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorCore") + .set_body_typed(Postproc::RewriteTensorCore); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc new file mode 100644 index 000000000000..624e6d27e844 --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! \brief The rewrite type for an unbound block */ +enum class BindType : int32_t { + /*! \brief No additional thread binding is needed */ + kNoBind = 0, + /*! \brief Need to bind to blockIdx */ + kBindBlock = 1, + /*! \brief Need to bind to both blockIdx and threadIdx */ + kBindBlockThread = 2, +}; + +/*! + * \brief Check the combination of bindings to be added to the block + * \param block_sref The block to be checked + * \param fuse_first_num The number of loops to be fused + * \return The type of binding to be added to the block + */ +BindType GetBindType(const StmtSRef& block_sref, int* fuse_first_num) { + Array loops = tir::GetLoops(block_sref); + int n = loops.size(); + if (n == 0) { + return BindType::kNoBind; + } + int i_block_idx = -1; + int i_thread_idx = -1; + int i_multi_child = -1; + int i_spatial_loop = -1; + for (int i = 0; i < n; ++i) { + const StmtSRef& loop_sref = loops[i]; + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + runtime::ThreadScope thread_scope = GetThreadScope(loop); + if (IsBlockIdx(thread_scope)) { + if (i_block_idx == -1) { + i_block_idx = i; + } + } + if (IsThreadIdx(thread_scope)) { + if (i_thread_idx == -1) { + i_thread_idx = i; + } + } + if (!IsSingleStmt(loop->body)) { + if (i_multi_child == -1) { + i_multi_child = i + 1; + } + } + if (tir::GetLoopIterType(loop_sref) == IterVarType::kDataPar) { + if (i_spatial_loop == i - 1) { + ++i_spatial_loop; + } + } + } + if (i_multi_child == -1) { + i_multi_child = n; + } + if ((i_block_idx != -1 && i_thread_idx != -1) || i_spatial_loop == -1) { + return BindType::kNoBind; + } else if (i_block_idx != -1 && i_thread_idx == -1) { + ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not"; + throw; + } else if (i_block_idx == -1 && i_thread_idx != -1) { + *fuse_first_num = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1); + return BindType::kBindBlock; + } else { // i_block_idx == -1 && i_thread_idx == -1 + *fuse_first_num = std::min(i_multi_child, i_spatial_loop + 1); + return BindType::kBindBlockThread; + } +} + +/*! \brief Find all the blocks that are not bound */ +class UnboundBlockFinder : private StmtVisitor { + public: + static std::vector> Find(const ScheduleState& self) { + UnboundBlockFinder finder(self); + for (const auto& kv : self->mod->functions) { + GlobalVar g_var = kv.first; + BaseFunc base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + finder.global_var_name_ = g_var->name_hint; + finder(Downcast(prim_func->body)->block->body); + } + } + return std::move(finder.blocks_); + } + + private: + void VisitStmt_(const ForNode* loop) final { + runtime::ThreadScope thread_scope = GetThreadScope(loop); + if (IsBlockIdx(thread_scope)) { + ++n_block_idx_; + } else if (IsThreadIdx(thread_scope)) { + ++n_thread_idx_; + } + if (n_block_idx_ == 0 || n_thread_idx_ == 0) { + StmtVisitor::VisitStmt_(loop); + } + if (IsBlockIdx(thread_scope)) { + --n_block_idx_; + } else if (IsThreadIdx(thread_scope)) { + --n_thread_idx_; + } + } + + void VisitStmt_(const BlockNode* block) final { + blocks_.emplace_back(self_->stmt2ref.at(block), global_var_name_); + } + + explicit UnboundBlockFinder(const ScheduleState& self) + : self_{self}, blocks_{}, n_block_idx_{0}, n_thread_idx_{0} {} + + /*! \brief The schedule state */ + const ScheduleState& self_; + /*! \brief The list of unbound blocks */ + std::vector> blocks_; + /*! \brief The number of blockIdx above the current stmt */ + int n_block_idx_; + /*! \brief The number of threadIdx above the current stmt */ + int n_thread_idx_; + /*! \brief The name of the global var */ + String global_var_name_; +}; + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +/*! \brief Add thread binding to unbound blocks */ +class RewriteUnboundBlockNode : public PostprocNode { + public: + // Inherited from PostprocNode + void InitializeWithTuneContext(const TuneContext& context) final { + CHECK(context->target.defined()) << "ValueError: target is not defined"; + Optional warp_size = context->target.value()->GetAttr("thread_warp_size"); + CHECK(warp_size.defined()) << "ValueError: missing attribute `thread_warp_size` in the target"; + this->warp_size_ = warp_size.value(); + } + + // Inherited from PostprocNode + bool Apply(const tir::Schedule& sch) final; + + public: + /*! \brief The cached warp size from Target */ + int warp_size_ = -1; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `warp_size_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.RewriteUnboundBlock"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteUnboundBlockNode, PostprocNode); +}; + +bool RewriteUnboundBlockNode::Apply(const tir::Schedule& sch) { + using tir::BlockRV; + using tir::LoopRV; + using tir::Schedule; + ICHECK_NE(this->warp_size_, -1); + std::vector> unbound_blocks = + tir::UnboundBlockFinder::Find(sch->state()); + for (const auto& kv : unbound_blocks) { + tir::StmtSRef block_sref = kv.first; + String global_var_name = kv.second; + int fuse_first_num = 0; + tir::BindType bind_type = tir::GetBindType(block_sref, &fuse_first_num); + if (bind_type == tir::BindType::kNoBind) { + continue; + } + BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); + Array loop_rvs = sch->GetLoops(block_rv); + LoopRV fused = sch->Fuse({loop_rvs.begin(), loop_rvs.begin() + fuse_first_num}); + if (bind_type == tir::BindType::kBindBlock) { + sch->Bind(fused, "blockIdx.x"); + } else if (bind_type == tir::BindType::kBindBlockThread) { + Array splits = sch->Split(fused, {NullOpt, Integer(this->warp_size_)}); + ICHECK_EQ(splits.size(), 2); + sch->Bind(splits[0], "blockIdx.x"); + sch->Bind(splits[1], "threadIdx.x"); + } + } + return true; +} + +Postproc Postproc::RewriteUnboundBlock() { + ObjectPtr n = make_object(); + n->warp_size_ = -1; + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteUnboundBlockNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteUnboundBlock") + .set_body_typed(Postproc::RewriteUnboundBlock); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc new file mode 100644 index 000000000000..19e5fde06f23 --- /dev/null +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { + +class ThreadExtentChecker : private StmtVisitor { + public: + static bool Check(const Stmt& stmt) { + try { + ThreadExtentChecker().VisitStmt(stmt); + return true; + } catch (const dmlc::Error& e) { + return false; + } + } + + private: + void VisitStmt_(const ForNode* loop) { + if (IsThreadIdx(GetThreadScope(loop))) { + if (const int64_t* p_ext = GetLoopIntExtent(loop)) { + thread_extent_product *= *p_ext; + StmtVisitor::VisitStmt_(loop); + thread_extent_product /= *p_ext; + return; + } else { + throw dmlc::Error("Dynamic thread extent"); + } + } + StmtVisitor::VisitStmt_(loop); + } + + void VisitStmt_(const BlockNode* block) { + if (Optional low_inclusive = + GetAnn(block, attr::meta_schedule_thread_extent_low_inclusive)) { + if (Optional high_inclusive = + GetAnn(block, attr::meta_schedule_thread_extent_high_inclusive)) { + int64_t low = low_inclusive.value()->value; + int64_t high = high_inclusive.value()->value; + if (!(low <= thread_extent_product && thread_extent_product <= high)) { + throw dmlc::Error("Thread extent"); + } + } + } + StmtVisitor::VisitStmt_(block); + } + + int64_t thread_extent_product = 1; +}; + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +/*! \brief Verify the correctness of the generated GPU code. */ +Integer Extract(const Target& target, const char* name) { + ICHECK(target.defined()); + if (Optional v = target->GetAttr(name)) { + return v.value(); + } + LOG(FATAL) << "AttributedError: \"" << name << "\" is not defined in the target"; + throw; +} + +/*! \brief Verify the correctness of the generated GPU code. */ +class VerifyGPUCodeNode : public PostprocNode { + public: + Map target_constraints_{nullptr}; + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(context->target.defined()); + Target target = context->target.value(); + this->target_constraints_ = Map{ + {"max_shared_memory_per_block", Extract(target, "shared_memory_per_block")}, + {"max_local_memory_per_block", Extract(target, "registers_per_block")}, + {"max_threads_per_block", Extract(target, "max_threads_per_block")}, + {"max_vthread", Integer(8)}, + {"max_vector_bytes", Integer(16)}}; + } + + bool Verify(const IRModule& mod) const { + for (const auto& kv : mod->functions) { + if (const auto* prim_func = kv.second.as()) { + if (!tir::VerifyGPUCode(GetRef(prim_func), this->target_constraints_)) { + return false; + } + } + } + return true; + } + + bool Apply(const tir::Schedule& sch) final { + IRModule mod = sch->mod(); + for (const auto& kv : mod->functions) { + const GlobalVar& g_var = kv.first; + const BaseFunc& base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + if (!tir::ThreadExtentChecker::Check(prim_func->body)) { + return false; + } + IRModule lowered{nullptr}; + try { + auto pass_list = Array(); + // Phase 1 + // First three passes are not needed in TIR schedule. + // pass_list.push_back(tir::transform::InjectPrefetch()); + // pass_list.push_back(tir::transform::TextureFlatten()); + // pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); + pass_list.push_back(tir::transform::LowerCrossThreadReduction()); + pass_list.push_back(tir::transform::LowerInitBlock()); + pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); + pass_list.push_back(tir::transform::ApplyBlockBoundPredicate()); + pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); + pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tir::transform::LowerAutoCopy()); + pass_list.push_back(tir::transform::UnifyThreadBinding()); + pass_list.push_back(tir::transform::LowerMatchBuffer()); + pass_list.push_back(tir::transform::InjectSoftwarePipeline()); + pass_list.push_back(tir::transform::FlattenBuffer()); + pass_list.push_back(tir::transform::BF16Legalize()); + pass_list.push_back(tir::transform::NarrowDataType(32)); + pass_list.push_back(tir::transform::Simplify()); + + // Phase 2 + pass_list.push_back(tir::transform::VectorizeLoop(true)); + pass_list.push_back(tir::transform::InjectVirtualThread()); + pass_list.push_back(tir::transform::InjectDoubleBuffer()); + pass_list.push_back(tir::transform::StorageRewrite()); + + // Convert Function to IRModule + transform::PassContext pass_ctx = transform::PassContext::Current(); + tir::PrimFunc f = WithAttr(GetRef(prim_func), "global_symbol", + runtime::String(g_var->name_hint)); + bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + if (noalias) { + f = WithAttr(std::move(f), "tir.noalias", Bool(true)); + } + IRModule mod = IRModule(Map({{GlobalVar(g_var->name_hint), f}})); + lowered = tvm::transform::Sequential(pass_list)(std::move(mod)); + } catch (const dmlc::Error& e) { + return false; + } + if (!Verify(lowered)) { + return false; + } + } + } + return true; + } + + static constexpr const char* _type_key = "meta_schedule.VerifyGPUCode"; + TVM_DECLARE_FINAL_OBJECT_INFO(VerifyGPUCodeNode, PostprocNode); +}; + +Postproc Postproc::VerifyGPUCode() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(VerifyGPUCodeNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocVerifyGPUCode").set_body_typed(Postproc::VerifyGPUCode); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc new file mode 100644 index 000000000000..75bb47c23dc3 --- /dev/null +++ b/src/meta_schedule/schedule_rule/add_rfactor.cc @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class AddRFactorNode : public ScheduleRuleNode { + public: + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(context->target.defined()); + Target target = context->target.value(); + this->max_parallel_basic_ = GetTargetNumCores(target); + if (this->max_jobs_per_core != -1) { + this->max_parallel_extent_ = max_parallel_basic_ * max_jobs_per_core; + } + } + + // Inherited from ScheduleRuleNode + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv); + + public: + /*! + * \brief The maximum number of jobs to be launched per core. + * It sets the uplimit of parallelism, i.e. `num_cores * max_jobs_per_core`. + * Use -1 to disable parallelism. + */ + int max_jobs_per_core; + /*! \brief The maximum size of the innermost factor */ + int max_innermost_factor; + /*! \brief The number of uplimit of parallelism. */ + int max_parallel_extent_; + /*! \brief The number of cores. */ + int max_parallel_basic_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("max_jobs_per_core", &max_jobs_per_core); + v->Visit("max_innermost_factor", &max_innermost_factor); + // `max_parallel_extent_` is not visited + // `max_parallel_basic_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.AddRFactor"; + TVM_DECLARE_FINAL_OBJECT_INFO(AddRFactorNode, ScheduleRuleNode); +}; + +ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core, + Optional max_innermost_factor) { + ObjectPtr n = make_object(); + n->max_jobs_per_core = max_jobs_per_core; + n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value; + n->max_parallel_extent_ = -1; + n->max_parallel_basic_ = -1; + return ScheduleRule(n); +} + +Array AddRFactorNode::Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) { + tir::StmtSRef block_sref = sch->GetSRef(block_rv); + if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_parallel_extent_, + max_parallel_basic_)) { + return {sch}; + } + + // Make a copy of the original schedule. + tir::Schedule ori_sch = sch->Copy(); + ori_sch->Seed(sch->ForkSeed()); + + // Reorder the loop axes if reduction loops are not innermost. + // After the reordering, fuse all the reduction loops. + size_t num_spatial_loops; + tir::LoopRV fused_reduce_loop; + ReorderAndFuseReductionLoops(sch, block_rv, &fused_reduce_loop, &num_spatial_loops); + + // Split the fused reduction loop. + Array factors = sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor); + const Array& split_loops = + sch->Split(fused_reduce_loop, {factors.begin(), factors.end()}); + + Array res; + for (const tir::LoopRV& split_loop : split_loops) { + tir::Schedule sch_tmp = sch->Copy(); + sch_tmp->Seed(sch->ForkSeed()); + const tir::BlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops); + Array axes = sch_tmp->GetLoops(block_rf); + ICHECK_GT(axes.size(), num_spatial_loops); + + // Annotate that the rfactor block, which is now the producer of the original block, needs to be + // considered by the rule Random-Compute-Location. + sch_tmp->Annotate(block_rv, tir::attr::meta_schedule_random_compute_producer, Bool(true)); + res.push_back(sch_tmp); + } + + res.push_back(ori_sch); + return res; +} + +TVM_REGISTER_NODE_TYPE(AddRFactorNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor") + .set_body_typed(ScheduleRule::AddRFactor); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc new file mode 100644 index 000000000000..a5b1a7abc6d6 --- /dev/null +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief The type of inline to be performed on a specific block */ +enum class InlineType : int32_t { + /*! \brief No inline opportunity */ + kNoInline = 0, + /*! \brief Inline the block into its consumer */ + kInlineIntoConsumer = 1, + /*! \brief Inline the block into its producer */ + kInlineIntoProducer = 2, +}; + +/*! \brief The rule that inlines spatial blocks if it satisfies some conditions. */ +class AutoInlineNode : public ScheduleRuleNode { + public: + /*! \brief Checks if the specific block should be inlined */ + inline InlineType CheckInline(const tir::Schedule& sch, const tir::BlockRV& block_rv); + + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final {} + + // Inherited from ScheduleRuleNode + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + InlineType inline_type = CheckInline(sch, block_rv); + if (inline_type == InlineType::kInlineIntoConsumer) { + sch->ComputeInline(block_rv); + } else if (inline_type == InlineType::kInlineIntoProducer) { + sch->ReverseComputeInline(block_rv); + } + return {sch}; + } + + public: + /*! \brief If allows to inline a block into its producer */ + bool into_producer; + /*! \brief If allows to inline a block into its consumer */ + bool into_consumer; + /*! \brief If it only allows to inline into a block generated by cache_read/write */ + bool into_cache_only; + /*! \brief Always inline constant tensors */ + bool inline_const_tensor; + /*! \brief Always disallow if-then-else-like constructs */ + bool disallow_if_then_else; + /*! \brief Always require the read-to-write mapping to be injective to do auto inline */ + bool require_injective; + /*! \brief Always require the read-to-write mapping to be ordered to do auto inline */ + bool require_ordered; + /*! \brief The operators that are disallowed in auto inline */ + Array disallow_op; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("into_producer", &into_producer); + v->Visit("into_consumer", &into_consumer); + v->Visit("into_cache_only", &into_cache_only); + v->Visit("inline_const_tensor", &inline_const_tensor); + v->Visit("disallow_if_then_else", &disallow_if_then_else); + v->Visit("require_injective", &require_injective); + v->Visit("require_ordered", &require_ordered); + v->Visit("disallow_op", &disallow_op); + } + + static constexpr const char* _type_key = "meta_schedule.AutoInline"; + TVM_DECLARE_FINAL_OBJECT_INFO(AutoInlineNode, ScheduleRuleNode); +}; + +inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch, + const tir::BlockRV& block_rv) { + using namespace tvm::tir; + StmtSRef block_sref = sch->GetSRef(block_rv); + ScheduleState state = sch->state(); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + BlockRealize realize = GetBlockRealize(state, block_sref); + // Cond 1. The block has only one write buffer + if (block->writes.size() != 1) { + return InlineType::kNoInline; + } + // Cond 2. The block is a spatial block + if (!IsSpatial(block_sref)) { + return InlineType::kNoInline; + } + // Cond 3. For a block that generates a constant tensor, ignore all other conditions + if (inline_const_tensor && block->reads.empty()) { + return InlineType::kInlineIntoConsumer; + } + // Cond 4. The block doesn't contain any disallowed operators + if (!disallow_op.empty() && HasOp(realize, disallow_op)) { + return InlineType::kNoInline; + } + // Cond 5. The block doesn't have any if-then-else-like constructs + if (disallow_if_then_else && HasIfThenElse(realize)) { + return InlineType::kNoInline; + } + // Cond 6. The mapping from read indices to write indices are injective and ordered + if (require_injective || require_ordered) { + const BufferRegion& write_region = block->writes[0]; + for (const BufferRegion& read_region : block->reads) { + bool injective, ordered; + constexpr auto _ = std::ignore; + std::tie(/*exists=*/_, /*surjective=*/_, injective, ordered, /*no_const_read=*/_, + /*no_shift_read=*/_) = AnalyzeReadWritePattern(read_region, write_region); + if (require_injective && injective == false) { + return InlineType::kNoInline; + } + if (require_ordered && ordered == false) { + return InlineType::kNoInline; + } + } + } + // Last cond: Check inline into the spatial consumer or the spatial producer + if (into_consumer) { + Array consumer_srefs = GetConsumers(state, block_sref); + if (!consumer_srefs.empty()) { + if (!into_cache_only || + tir::GetAnn(consumer_srefs[0], tir::attr::meta_schedule_cache_type).defined()) { + if (CanComputeInline(state, block_sref)) { + return InlineType::kInlineIntoConsumer; + } + } + } + } + if (into_producer) { + Array producer_srefs = GetProducers(state, block_sref); + if (producer_srefs.size() == 1 && IsSpatial(producer_srefs[0])) { + if (!into_cache_only || + tir::GetAnn(producer_srefs[0], tir::attr::meta_schedule_cache_type).defined()) { + if (CanReverseComputeInline(state, block_sref)) { + return InlineType::kInlineIntoProducer; + } + } + } + } + return InlineType::kNoInline; +} + +ScheduleRule ScheduleRule::AutoInline(bool into_producer, // + bool into_consumer, // + bool into_cache_only, // + bool inline_const_tensor, // + bool disallow_if_then_else, // + bool require_injective, // + bool require_ordered, // + Optional> disallow_op) { + ObjectPtr n = make_object(); + n->into_producer = into_producer; + n->into_consumer = into_consumer; + n->into_cache_only = into_cache_only; + n->inline_const_tensor = inline_const_tensor; + n->disallow_if_then_else = disallow_if_then_else; + n->require_injective = require_injective; + n->require_ordered = require_ordered; + n->disallow_op.clear(); + if (disallow_op.defined()) { + Array op_names = disallow_op.value(); + n->disallow_op.reserve(op_names.size()); + for (const String& op_name : op_names) { + n->disallow_op.push_back(Op::Get(op_name)); + } + } + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(AutoInlineNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline") + .set_body_typed(ScheduleRule::AutoInline); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc new file mode 100644 index 000000000000..0c8546ccfcdd --- /dev/null +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class CrossThreadReductionNode : public ScheduleRuleNode { + public: + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(context->target.defined()); + Target target = context->target.value(); + + Optional opt_max_threads_per_block = target->GetAttr("max_threads_per_block"); + Optional opt_warp_size = target->GetAttr("thread_warp_size"); + + if (!opt_max_threads_per_block.defined()) { + LOG(WARNING) << "Target does not have attribute \"max_threads_per_block\", therefore the " + "rule CrossThreadReduction will not be applied"; + } + if (!opt_warp_size.defined()) { + LOG(WARNING) << "Target does not have attribute \"thread_warp_size\", therefore the rule " + "CrossThreadReduction will not be applied"; + } + max_threads_per_block = opt_max_threads_per_block.value_or(Integer(-1))->value; + warp_size = opt_warp_size.value_or(Integer(-1))->value; + } + + // Inherited from ScheduleRuleNode + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + // Step 0. Check the conditions of this rule. + if (max_threads_per_block == -1 || warp_size == -1) { + return {sch}; + } + const tir::StmtSRef& block_sref = sch->GetSRef(block_rv); + if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_threads_per_block, + warp_size)) { + return {sch}; + } + + // Step 1. Make a copy of the original schedule. The new copy is used for scheduling. + tir::Schedule tmp_sch = sch->Copy(); + tmp_sch->Seed(sch->ForkSeed()); + + // Step 2. Check the opportunity for block fusion. We say "fusible", if we can compute-at the + // block to its consumers. We want to fuse as much as possible because it results in + // significantly faster schedule. + bool fusible = false; + // `target_loop` is the loop position where the input block will be computed at. + tir::LoopRV target_loop{nullptr}; + // `target_block` is the consumer block that we want to compute-at the input block to. + tir::BlockRV target_block{nullptr}; + // `tgt_block_innermost_loop` is the innermost loop outside the target block. + tir::LoopRV tgt_block_innermost_loop{nullptr}; + + std::tie(fusible, target_loop, target_block, tgt_block_innermost_loop) = + GetComputeTargetLoopAndBlock(tmp_sch, block_rv); + + // Step 3. Try block fusion. + int n_candidate = static_cast(thread_extents.size()); + Array probs(n_candidate, FloatImm(DataType::Float(64), 1.0 / n_candidate)); + tir::ExprRV thread_extent = tmp_sch->SampleCategorical(thread_extents, probs); + if (fusible) { + ICHECK(target_block.defined()); + ICHECK(target_loop.defined()); + + // Step 3.1. + // - If the outer loops of `target_block` haven't been bound to "threadIdx.x", we should first + // bound the innermost outer loop of `target_block` to threadIdx. Possibly we need to split + // the loop before binding. + // - Otherwise, we search for the extent of "threadIdx.x" and use it as the split factor. + if (!InThreadScope(tmp_sch, target_block)) { + const Array& split_res = + tmp_sch->Split(tgt_block_innermost_loop, {NullOpt, thread_extent}); + tmp_sch->Bind(split_res[1], "threadIdx.x"); + if (tgt_block_innermost_loop.same_as(target_loop)) { + target_loop = split_res[0]; + } + } else { + thread_extent = GetThreadIdxExtentFromTrace(tmp_sch->trace().value()); + } + // Step 3.2. Do the compute-at. + tmp_sch->ComputeAt(block_rv, target_loop, /*preserve_unit_loops=*/true); + // Step 3.3. Set the storage scope of the output buffer to shared memory. + tmp_sch->SetScope(block_rv, /*buffer_index=*/0, /*storage_scope=*/"shared"); + } + + // Step 4. Reorder the loop axes if reduction loops are not innermost. After the reordering, + // fuse all the reduction loops. + size_t num_spatial_loops; + tir::LoopRV fused_reduce_loop; + ReorderAndFuseReductionLoops(tmp_sch, block_rv, &fused_reduce_loop, &num_spatial_loops); + // Step 5. Split the fused reduction loop and bind the inner one to threadIdx. + const Array& split_res = + tmp_sch->Split(fused_reduce_loop, {NullOpt, thread_extent}); + tmp_sch->Bind(split_res[1], "threadIdx.x"); + + return {tmp_sch, sch}; + } + + private: + /*! + * \brief Check whether the input block is in thread scope, i.e., some of its outer loop is + * bound to threadIdx. + * \param sch The TensorIR schedule + * \param block The block to be checked + * \return A boolean indicating whether the block is in thread scope. + */ + bool InThreadScope(const tir::Schedule& sch, const tir::BlockRV& block) { + const Array& axes = sch->GetLoops(block); + for (const tir::LoopRV& loop_rv : axes) { + const tir::For& loop = sch->Get(loop_rv); + runtime::ThreadScope thread_scope = tir::GetThreadScope(loop.get()); + if (tir::IsThreadIdx(thread_scope)) { + return true; + } + } + return false; + } + + /*! + * \brief Get the ExprRV which used to define the extent of a given loop. + * \param trace The trace of the schedule, where the extent is to be found + * \param loop The loop whose extent is to be found + * \param extent The finding result + * \return Whether the find is successful. + */ + bool GetLoopRVExtentSource(const tir::Trace& trace, const tir::LoopRV& loop, + tir::ExprRV* extent) { + for (const tir::Instruction& inst : trace->insts) { + if (inst->kind->name == "Split") { + int i = std::find(inst->outputs.begin(), inst->outputs.end(), loop) - inst->outputs.begin(); + CHECK(inst->inputs[1 + i].defined()) + << "ValueError: Extracting an extent which needs inference is not supported so far"; + *extent = Downcast(inst->inputs[1 + i]); + return true; + } + } + return false; + } + + /*! + * \brief Get the ExprRV extent of "threadIdx.x" in the given schedule trace. + * \param trace The trace of the schedule, where the extent is to be found + * \return The extent of "threadIdx.x" in the input schedule + */ + tir::ExprRV GetThreadIdxExtentFromTrace(const tir::Trace& trace) { + tir::ExprRV extent{nullptr}; + for (const tir::Instruction& inst : trace->insts) { + if (inst->kind->name == "Bind" && Downcast(inst->attrs[0]) == "threadIdx.x") { + if (GetLoopRVExtentSource(trace, Downcast(inst->inputs[0]), &extent)) { + return extent; + } + } + } + CHECK(false) << "ValueError: Unable to get the extent of \"threadIdx.x\""; + throw; + } + + /*! + * \brief Get the compute-at target loop and the first block under the target loop. + * \param sch The TensorIR schedule + * \param block_rv The block whose compute-at target loop is queried + * \return A tuple consisting of + * 1. a boolean indicating whether the block can be computed at some target loop (a.k.a. fusible); + * 2. the compute-at target loop when fusible, or a null loop random variable; + * 3. the first block under the target loop when fusible, or a null block random variable; + * 4. the innermost loop outside the target block when fusible, or a null block random variable. + */ + std::tuple GetComputeTargetLoopAndBlock( + const tir::Schedule& sch, const tir::BlockRV& block_rv) { + // Step 1. Get all the consumers of the input block. + Array consumers = sch->GetConsumers(block_rv); + + // Step 2. If the block has no consumer or the first consumer needs multi-level tiling, it is + // not fusible. + if (consumers.empty() || tir::NeedsMultiLevelTiling(sch->state(), sch->GetSRef(consumers[0]))) { + return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, + tir::LoopRV{nullptr}); + } + + // Step 3. Calculate the lowest common ancestor of all the consumers. + // - If the lowest common ancestor is a block: + // - if there is only one consumer, the target block is that consumer; + // - if there are multiple consumers, they must not share a common loop, and the case is not + // fusible; + // - If the lowest common ancestor is a loop, the target block is also the first consumer. + const tir::StmtSRef& lca_sref = + tir::GetSRefLowestCommonAncestor(tir::BlockRVs2StmtSRefs(sch, consumers)); + if (consumers.size() > 1 && lca_sref->StmtAs() != nullptr) { + return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, + tir::LoopRV{nullptr}); + } + + // Step 4. Get the outer loops of the target block, and get the compute-at position index. + Array tgt_block_loops = sch->GetLoops(consumers[0]); + int pos = GetComputePosition(sch, sch->GetLoops(block_rv), tgt_block_loops, lca_sref); + + // Step 5. A negative position index means not fusible, and vice-versa. + if (pos < 0) { + return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, + tir::LoopRV{nullptr}); + } else { + return std::make_tuple(true, tgt_block_loops[pos], consumers[0], tgt_block_loops.back()); + } + } + + /*! + * \brief Get the compute-at position index of the input block, according to + * 1. the loops outside the input block; + * 2. the loops outside the target block; + * 3. the lowest common ancestor of all the consumers of the input block. + * \param sch The TensorIR schedule + * \param block_loops The loops outside the input block + * \param tgt_block_loops The loops outside the target block + * \param lca_sref The lowest common ancestor of all the consumers of the input block + * \return The compute-at position index of the input block + */ + int GetComputePosition(const tir::Schedule& sch, const Array& block_loops, + const Array& tgt_block_loops, const tir::StmtSRef& lca_sref) { + int n_block_loop = static_cast(block_loops.size()); + int n_tgt_block_loop = static_cast(tgt_block_loops.size()); + + for (int i = 0; i < n_block_loop && i < n_tgt_block_loop; ++i) { + if (tir::GetLoopIterType(sch->GetSRef(block_loops[i])) != tir::IterVarType::kDataPar) { + return i - 1; + } else if (sch->GetSRef(tgt_block_loops[i]).same_as(lca_sref)) { + // If the lowest common ancestor is a loop, the compute location of the input block should + // not be deeper than the LCA loop. + return i; + } + } + return std::min(n_block_loop, n_tgt_block_loop) - 1; + } + + public: + /*! \brief The maximum number of threads allowed in a thread block */ + int max_threads_per_block; + /*! \brief The number of threads per warp */ + int warp_size; + /*! \brief Candidates of thread axis extent (values are required to be positive). */ + Array thread_extents; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("max_threads_per_block", &max_threads_per_block); + v->Visit("warp_size", &warp_size); + v->Visit("thread_extents", &thread_extents); + } + + static constexpr const char* _type_key = "meta_schedule.CrossThreadReduction"; + TVM_DECLARE_FINAL_OBJECT_INFO(CrossThreadReductionNode, ScheduleRuleNode); +}; + +ScheduleRule ScheduleRule::CrossThreadReduction(Array thread_extents) { + for (const Integer& extent : thread_extents) { + CHECK(extent->value > 0) << "ValueError: The candidates of thread extent must be positive"; + } + ObjectPtr n = make_object(); + n->thread_extents = std::move(thread_extents); + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(CrossThreadReductionNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleCrossThreadReduction") + .set_body_typed(ScheduleRule::CrossThreadReduction); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc new file mode 100644 index 000000000000..e0438a2eb1ed --- /dev/null +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -0,0 +1,626 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { +/*! + * \brief Get the buffer dimensions for all the read buffers of a block, but marks the reduction + * buffers' dimensions as -1 + * \param block_sref The block to be processed + * \return The buffer dimensions for all the read buffers of a block, except for reduction buffers + * \note The method is not designed for generic analysis and relies on assumptions in the scenario + * of multi-level tiling, so it's intentionally kept inside this file not in the analysis header + */ +std::vector GetReadBufferNDims(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BufferNode* write_buffer = block->writes[0]->buffer.get(); + int n = block->reads.size(); + std::vector results(n, -1); + for (int i = 0; i < n; ++i) { + const BufferNode* read_buffer = block->reads[i]->buffer.get(); + if (read_buffer != write_buffer) { + results[i] = read_buffer->shape.size(); + } + } + return results; +} + +Optional TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, + const String& intrin_name) { + Optional opt_tensorize_info = GetTensorizeLoopMapping( + sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->description); + if (!opt_tensorize_info) return NullOpt; + const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get(); + // Construct a mapping from tir loops back to LoopRVs + Map loop2rv; + { + Array loop_rvs = sch->GetLoops(block_rv); + for (const LoopRV& loop_rv : loop_rvs) { + loop2rv.Set(sch->GetSRef(loop_rv), loop_rv); + } + } + // Split the loops + arith::Analyzer analyzer; + std::unordered_set inner_loops; + std::vector reorder_suffix; + reorder_suffix.resize(info->loop_map.size()); + for (const auto& kv : info->loop_map) { + // Extract mapping (block_loop => desc_loop) + const tir::StmtSRef& block_loop_sref = kv.first; + const tir::ForNode* block_loop = block_loop_sref->StmtAs(); + const tir::ForNode* desc_loop = kv.second.get(); + ICHECK(block_loop != nullptr && desc_loop != nullptr); + // Extract the loop extent + PrimExpr block_extent = analyzer.Simplify(block_loop->extent); + PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent); + const auto* int_block_extent = block_extent.as(); + const auto* int_desc_extent = desc_extent.as(); + ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr); + // Check divisibility + int64_t total = int_block_extent->value; + int64_t inner = int_desc_extent->value; + ICHECK_EQ(total % inner, 0); + int64_t outer = int_block_extent->value / int_desc_extent->value; + // Do the split + Array split = sch->Split(loop2rv.at(block_loop_sref), {Integer(outer), Integer(inner)}); + ICHECK_EQ(split.size(), 2); + inner_loops.insert(sch->GetSRef(split[1]).operator->()); + // The inner split will be reordered to the loop domain that is tensorized + int desc_loop_index = info->desc_loop_indexer.at(GetRef(desc_loop)); + reorder_suffix[desc_loop_index] = split[1]; + } + // Reorder the loops + std::vector reorder_list; + bool meet = false; + Array all_loops = sch->GetLoops(block_rv); + for (const LoopRV& loop : all_loops) { + if (inner_loops.count(sch->GetSRef(loop).operator->())) { + meet = true; + } else if (meet) { + reorder_list.push_back(loop); + } + } + reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end()); + sch->Reorder(reorder_list); + ICHECK(!reorder_suffix.empty()); + return reorder_suffix[0]; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::BlockRV; +using tir::ExprRV; +using tir::IterVarType; +using tir::LoopRV; +using tir::Schedule; + +/*! + * \brief Configuration of data reuse type: + * 0) kNoReuse: no reuse is allowed, then no cache_read/write is performed. + * 1) kMayReuse: reuse is allowed, but no reuse is explored. + * 2) kMustReuse: reuse is allowed and no reuse is not explored. + */ +enum class ReuseType : int32_t { + kNoReuse = 0, + kMayReuse = 1, + kMustReuse = 2, +}; + +/*! + * \brief Converts a string to ReuseType. + * \param str The string to be converted. + * \return The converted ReuseType. + */ +ReuseType Str2ReuseType(const String& str) { + if (str == "no") { + return ReuseType::kNoReuse; + } else if (str == "may") { + return ReuseType::kMayReuse; + } else if (str == "must") { + return ReuseType::kMustReuse; + } else { + LOG(FATAL) << "ValueError: Unknown ReuseType: " << str; + throw; + } +} + +/*! \brief Configuration of data reuse patterns */ +struct ReuseConfig { + /*! \brief Type of data reuse: no-reuse, may-reuse or must-reuse */ + ReuseType req; + /*! \brief Which levels are caching stage inserted at */ + std::vector levels; + /*! \brief The storage scope */ + String scope; + + /*! \brief Default constructor: no data reuse */ + ReuseConfig() : req(ReuseType::kNoReuse) {} + + /*! \brief Construct from a configuration dictionary */ + explicit ReuseConfig(const Map& config) + : req(Str2ReuseType(Downcast(config.at("req")))), + levels(support::AsVector(Downcast>(config.at("levels")))), + scope(Downcast(config.at("scope"))) { + ICHECK_EQ(config.size(), 3); + } +}; + +/*! \brief The state of auto scheduling for the multi-level tiling rule */ +struct State { + /*! \brief The schedule to date */ + Schedule sch; + /*! \brief The block to be tiled */ + BlockRV block_rv; + /*! \brief The write cache */ + Optional write_cache; + /*! \brief Indicating if the write cache is generated by cache_write */ + bool write_cache_is_added; + /*! \brief The loop tiles */ + Array> tiles; + /*! \brief Whether Tensor Core is used for the inner computation */ + bool tensor_core_is_used; + /*! \brief The Tensor Core cache read block A for Tensor Core computation */ + Optional tensor_core_load_A; + /*! \brief The Tensor Core cache read block B for Tensor Core computation */ + Optional tensor_core_load_B; + /*! \brief The Tensor Core cache write block for Tensor Core computation */ + Optional tensor_core_store; + + /*! \brief Default constructor */ + explicit State(Schedule sch, BlockRV block_rv, Optional write_cache = NullOpt, + bool write_cache_is_added = false, Array> tiles = {}, + bool tensor_core_is_used = false) + : sch(sch), + block_rv(block_rv), + write_cache(write_cache), + write_cache_is_added(write_cache_is_added), + tiles(tiles), + tensor_core_is_used(tensor_core_is_used) {} +}; + +/*! + * \brief Helper to apply a sub-rule to a list of auto scheduling states + * \tparam FLambda The type of the sub-rule functor + * \param states The list of states to be applied + * \return The list of states after applying the sub-rule + */ +template +std::vector SubRule(std::vector states, FLambda sub_rule) { + std::vector results; + for (auto&& state : states) { + std::vector next = sub_rule(std::move(state)); + results.insert(results.end(), // + std::make_move_iterator(next.begin()), // + std::make_move_iterator(next.end())); + } + return results; +} + +/*! + * \brief The mega rule: multi-level tiling with data reuse + */ +class MultiLevelTilingNode : public ScheduleRuleNode { + public: + // SubRule 0. detect compute intrin + inline std::vector DetectTensorCore(State state) const; + // SubRule 1. add write cache + inline std::vector AddWriteReuse(State state) const; + // SubRule 2. tile the loop nest + inline std::vector TileLoopNest(State state) const; + // SubRule 3. add read cache + inline std::vector AddReadReuse(State state) const; + // SubRule 4. fuse write cache + inline std::vector FuseWriteReuse(State state) const; + + State TensorCoreLoad(State state) const { + // Add the cache read stage for Tensor Core + state.tensor_core_load_A = state.sch->CacheRead(state.block_rv, 1, "wmma.matrix_a"); + state.tensor_core_load_B = state.sch->CacheRead(state.block_rv, 2, "wmma.matrix_b"); + const Array& r_tiles = state.tiles[r_indices_.back()]; + // Insert cache_read block to the proper place + ICHECK(!r_tiles.empty()) << "ValueError: Cannot find any reduction loop in the block"; + state.sch->ComputeAt(state.tensor_core_load_A.value(), r_tiles.back(), true); + state.sch->ComputeAt(state.tensor_core_load_B.value(), r_tiles.back(), true); + // Annotate the block + state.sch->Annotate(state.tensor_core_load_A.value(), tir::attr::meta_schedule_auto_tensorize, + String("wmma_load_a")); + state.sch->Annotate(state.tensor_core_load_B.value(), tir::attr::meta_schedule_auto_tensorize, + String("wmma_load_b")); + return state; + } + + State TensorCoreStore(State state) const { + // Add the cache read stage for Tensor Core + state.tensor_core_store = state.sch->CacheWrite(state.block_rv, 0, "wmma.accumulator"); + // Annotate the block + state.sch->Annotate(state.tensor_core_store.value(), tir::attr::meta_schedule_auto_tensorize, + String("wmma_store")); + return state; + } + + State TensorCoreStoreFusion(State state, int level) const { + const LoopRV& loop = state.tiles[level].back(); + state.sch->ReverseComputeAt(state.tensor_core_store.value(), loop, true); + return state; + } + + BlockRV GetRootBlockRV(const Schedule& sch, BlockRV block_rv) const { + const tir::StmtSRefNode* block = sch->GetSRef(block_rv).get(); + for (; block->parent != nullptr; block = block->parent) + ; + for (const auto& kv : sch->mod()->functions) { + const GlobalVar& gv = kv.first; + const BaseFunc& base_func = kv.second; + if (const auto* func = base_func.as()) { + const tir::BlockNode* root = func->body.as()->block.get(); + if (root == block->StmtAs()) { + BlockRV root_rv = sch->GetBlock(root->name_hint, gv->name_hint); + return root_rv; + } + } + } + ICHECK(false) << "Ill schedule data structure"; + throw; + } + + // Do nothing; Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final { + if (Optional v = context->target.value()->GetAttr("max_threads_per_block")) { + this->max_threads_per_block_ = v.value()->value; + if (Optional v = context->target.value()->GetAttr("thread_warp_size")) { + this->thread_warp_size_ = v.value()->value; + } else { + LOG(INFO) << "'thread_warp_size' is not defined in the target"; + } + } + } + + // Entry of the mega rule; Inherited from ScheduleRuleNode + Array Apply(const Schedule& sch, const BlockRV& block_rv) final { + if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { + return {sch}; + } + sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); + + std::vector states{State(sch, block_rv)}; + states = SubRule(std::move(states), [&](State state) { return DetectTensorCore(state); }); + states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); }); + states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); }); + states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); }); + states = SubRule(std::move(states), [&](State state) { return FuseWriteReuse(state); }); + Array results; + for (auto&& state : states) { + results.push_back(std::move(state.sch)); + } + return results; + } + + public: + /*! + * \brief The tiling structure. Recommended: + * - 'SSRSRS' on CPU + * - 'SSSRRSRS' on GPU + */ + String structure; + /*! \brief For each level of tiles, which thread axis it is bound to */ + Array tile_binds; + /*! \brief Whether to use Tensor Core */ + bool use_tensor_core; + /*! \brief The maximum size of the innermost factor */ + int max_innermost_factor; + /*! \brief The length of vector lane in vectorized cooperative fetching */ + std::vector vector_load_lens; + /*! \brief Data reuse configuration for reading */ + ReuseConfig reuse_read_; + /*! \brief Data reuse configuration for writing */ + ReuseConfig reuse_write_; + /*! \brief The indices of spatial tiles in `structure` */ + std::vector s_indices_; + /*! \brief The indices of reduction tiles in `structure` */ + std::vector r_indices_; + /*! \brief The size of the thread warp */ + int thread_warp_size_; + /*! \brief The maximum number of threads to be used size of a thread warp */ + int max_threads_per_block_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("structure", &structure); + v->Visit("tile_binds", &tile_binds); + v->Visit("use_tensor_core", &use_tensor_core); + v->Visit("max_innermost_factor", &max_innermost_factor); + // `vector_load_lens` is not visited + // `reuse_read_` is not visited + // `reuse_write_` is not visited + // `s_indices_` is not visited + // `r_indices_` is not visited + // `thread_warp_size_` is not visited + // `max_threads_per_block` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.MultiLevelTiling"; + TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode); +}; + +inline std::vector MultiLevelTilingNode::DetectTensorCore(State state) const { + std::vector result; + // If Tensor Core is not allowed, we skip this subrule + if (!use_tensor_core) return {state}; + // Do tiling to match Tensor Core wmma sync intrin + BlockRV block_rv = state.block_rv; + Optional tiled_loop_rv = TilingwithTensorIntrin(state.sch, block_rv, "wmma_sync"); + if (!tiled_loop_rv.defined()) return {state}; + // Do blockize + state.block_rv = state.sch->Blockize(tiled_loop_rv.value()); + // Annotate the block + state.sch->Annotate(block_rv, tir::attr::meta_schedule_auto_tensorize, String("wmma_sync")); + state.sch->Annotate(state.block_rv, tir::attr::meta_schedule_auto_tensorize, String("wmma_fill")); + state.tensor_core_is_used = true; + // Annotate the root block to notify the following postprocessors + state.sch->Annotate(GetRootBlockRV(state.sch, state.block_rv), + tir::attr::meta_schedule_tensor_core_enabled, String("1")); + result.push_back(state); + return result; +} + +inline std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { + const ReuseConfig& config = this->reuse_write_; + if (config.req == ReuseType::kNoReuse) { + if (state.tensor_core_is_used) state = TensorCoreStore(state); + return {std::move(state)}; + } + // Case 1. If the write cache is already there, we don't need to add another. + if (config.req == ReuseType::kMayReuse) { + Array consumer_rvs = state.sch->GetConsumers(state.block_rv); + if (consumer_rvs.size() == 1 && IsWriteCache(state.sch->GetSRef(consumer_rvs[0]))) { + state.write_cache = consumer_rvs[0]; + state.write_cache_is_added = false; + if (state.tensor_core_is_used) state = TensorCoreStore(state); + return {std::move(state)}; + } + } + std::vector results; + results.reserve(2); + // Case 2. No write cache is added + if (config.req == ReuseType::kMayReuse) { + State new_state(/*sch=*/state.sch->Copy(), /*block_rv=*/state.block_rv, + /*write_cache=*/NullOpt, + /*write_cache_is_added=*/false); + new_state.sch->Seed(state.sch->ForkSeed()); + if (new_state.tensor_core_is_used) new_state = TensorCoreStore(new_state); + results.emplace_back(std::move(new_state)); + } + // Case 3. Add one write cache + BlockRV write_cache = state.sch->CacheWrite(/*block_rv=*/state.block_rv, /*read_buffer_index=*/0, + /*storage_scope=*/config.scope); + state.write_cache = write_cache; + { + tir::Annotate(state.sch->state(), state.sch->GetSRef(write_cache), // + tir::attr::meta_schedule_cache_type, // + Integer(tir::attr::meta_schedule_cache_type_write)); + } + state.write_cache_is_added = true; + if (state.tensor_core_is_used) state = TensorCoreStore(state); + results.emplace_back(std::move(state)); + return results; +} + +inline std::vector MultiLevelTilingNode::TileLoopNest(State state) const { + Schedule& sch = state.sch; + const BlockRV& block_rv = state.block_rv; + // Step 1. Assuming trivial binding, pair the loops and their iter-var-types + Array loops = sch->GetLoops(block_rv); + std::vector iter_types = GetBlockVarTypes(sch->GetSRef(state.block_rv)); + ICHECK_EQ(loops.size(), iter_types.size()); + // Step 2. For each loop axis, tile it + int64_t spatial_loop_product = 1; + std::vector> tiles(s_indices_.size() + r_indices_.size()); + for (int i = 0, n = loops.size(); i < n; ++i) { + LoopRV loop = loops[i]; + const std::vector* idx = nullptr; + if (iter_types[i] == IterVarType::kDataPar) { + idx = &s_indices_; + if (spatial_loop_product != -1) { + if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(loop).get())) { + spatial_loop_product *= *extent; + } else { + spatial_loop_product = -1; + } + } + } else if (iter_types[i] == IterVarType::kCommReduce) { + idx = &r_indices_; + } else { + continue; + } + // Do the split + int n_tiles = idx->size(); + Array factors = sch->SamplePerfectTile( + /*loop=*/loop, + /*n=*/n_tiles, + /*max_innermost_factor=*/max_innermost_factor); + Array splits = sch->Split(/*loop=*/loop, + /*factors=*/{factors.begin(), factors.end()}); + // Put every tile to its slot + for (int j = 0; j < n_tiles; ++j) { + tiles[idx->at(j)].push_back(splits[j]); + } + } + // Step 3. Reorder to organize the tiles + sch->Reorder(support::ConcatArrayList(tiles.begin(), tiles.end())); + // Step 4. Bind the tiles to threads + int n_binds = std::min(tile_binds.size(), tiles.size()); + for (int i = 0; i < n_binds; ++i) { + LoopRV fused = sch->Fuse(tiles[i]); + sch->Bind(fused, tile_binds[i]); + tiles[i] = {fused}; + } + state.tiles = Array>{tiles.begin(), tiles.end()}; + if (this->thread_warp_size_ != -1) { + int64_t low_inclusive = 1; + int64_t high_inclusive = this->max_threads_per_block_; + if (spatial_loop_product > 2 * this->thread_warp_size_) { + low_inclusive = this->thread_warp_size_; + } + sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_low_inclusive, + Integer(low_inclusive)); + sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_high_inclusive, + Integer(high_inclusive)); + } + return {state}; +} + +inline std::vector MultiLevelTilingNode::AddReadReuse(State state) const { + const ReuseConfig& config = this->reuse_read_; + if (config.req == ReuseType::kNoReuse) { + if (state.tensor_core_is_used) state = TensorCoreLoad(state); + return {std::move(state)}; + } + ICHECK(config.req != ReuseType::kMayReuse); + const BlockRV& block_rv = state.block_rv; + std::vector results; + results.reserve(config.levels.size()); + for (int level : config.levels) { + Schedule sch = state.sch->Copy(); + sch->Seed(state.sch->ForkSeed()); + const LoopRV& loop_rv = state.tiles[level - 1].back(); + // Enumerate all buffers that are read but not written + std::vector read_buffer_ndims = tir::GetReadBufferNDims(sch->GetSRef(block_rv)); + for (int i = 0, n_reads = read_buffer_ndims.size(); i < n_reads; ++i) { + int buffer_ndim = read_buffer_ndims[i]; + if (buffer_ndim == -1) { + continue; + } + // Do cache_read + BlockRV cache_read_block = sch->CacheRead(block_rv, i, config.scope); + { + tir::Annotate(sch->state(), sch->GetSRef(cache_read_block), // + tir::attr::meta_schedule_cache_type, + Integer(tir::attr::meta_schedule_cache_type_read)); + } + // Insert cache_read block to the proper place + sch->ComputeAt(cache_read_block, loop_rv, true); + // Fuse the iterators of the cache_read + Array buffer_loops = sch->GetLoops(cache_read_block); + LoopRV fused = sch->Fuse(Array{buffer_loops.end() - buffer_ndim, // + buffer_loops.end()}); + // Annotate cooperative fetching + if (!vector_load_lens.empty()) { + int n = vector_load_lens.size(); + double prob = 1.0 / n; + ExprRV vector_load_len = + sch->SampleCategorical(support::AsArray(vector_load_lens), + Array(n, FloatImm(DataType::Float(64), prob))); + sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch, + vector_load_len); + } + } + State new_state = state; + new_state.sch = sch; + if (new_state.tensor_core_is_used) new_state = TensorCoreLoad(new_state); + results.push_back(std::move(new_state)); + } + return results; +} + +inline std::vector MultiLevelTilingNode::FuseWriteReuse(State state) const { + const ReuseConfig& config = this->reuse_write_; + if (config.req == ReuseType::kNoReuse) { + if (state.tensor_core_is_used) state = TensorCoreStoreFusion(state, r_indices_.front() - 1); + return {std::move(state)}; + } + // If the only-consumer does not exist, or is not elementwise, then do not do fusion + if (!state.write_cache.defined()) { + if (state.tensor_core_is_used) state = TensorCoreStoreFusion(state, r_indices_.front() - 1); + return {std::move(state)}; + } + std::vector results; + // Special case. + // Stages added by `cache_write` must be fused at some level, otherwise it has no benefit. + // On the other hand, If the consumer stage is not added by `cache_write`, + // we may choose not to fuse by setting `must_cache_write = False` + if (!state.write_cache_is_added && config.req != ReuseType::kMustReuse) { + results.push_back(state); + } + BlockRV consumer = state.write_cache.value(); + // Enumerate the level of tile to be fused at + for (int level : config.levels) { + State new_state = state; + new_state.sch = state.sch->Copy(); + new_state.sch->Seed(state.sch->ForkSeed()); + const LoopRV& loop_rv = new_state.tiles[level - 1].back(); + if (new_state.tensor_core_is_used) new_state = TensorCoreStoreFusion(new_state, level - 1); + new_state.sch->ReverseComputeAt(consumer, loop_rv, true); + results.push_back(std::move(new_state)); + } + return results; +} + +// Constructor + +ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional> tile_binds, + bool use_tensor_core, + Optional max_innermost_factor, + Optional> vector_load_lens, + Optional> reuse_read, + Optional> reuse_write) { + ObjectPtr n = make_object(); + n->structure = structure; + n->tile_binds = tile_binds.value_or({}); + n->use_tensor_core = use_tensor_core; + if (use_tensor_core) { + // Check whether corresponding wmma intrinsics are registered + tir::TensorIntrin::Get("wmma_sync"); + tir::TensorIntrin::Get("wmma_load_a"); + tir::TensorIntrin::Get("wmma_load_b"); + tir::TensorIntrin::Get("wmma_store"); + tir::TensorIntrin::Get("wmma_fill"); + } + n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value; + n->vector_load_lens = vector_load_lens.defined() + ? support::AsVector(vector_load_lens.value()) + : std::vector(); + n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig(); + n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig(); + for (int i = 0, len = structure.size(); i < len; ++i) { + char c = structure.data()[i]; + if (c == 'S') { + n->s_indices_.push_back(i); + } else if (c == 'R') { + n->r_indices_.push_back(i); + } else { + LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure; + } + } + n->thread_warp_size_ = -1; + n->max_threads_per_block_ = -1; + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(MultiLevelTilingNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTiling") + .set_body_typed(ScheduleRule::MultiLevelTiling); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc new file mode 100644 index 000000000000..b7100be9254f --- /dev/null +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +bool IsRootWithNoAnnotation(const Schedule& sch, const BlockRV& block_rv) { + StmtSRef block_sref = sch->GetSRef(block_rv); + if (block_sref->parent != nullptr) { + return false; + } + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + return block->annotations.empty(); +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { + public: + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(context->target.defined()); + if (this->max_jobs_per_core != -1) { + Target target = context->target.value(); + this->max_parallel_extent_ = GetTargetNumCores(target) * max_jobs_per_core; + } + } + + // Inherited from ScheduleRuleNode + Array Apply(const tir::Schedule& sch, const tir::BlockRV& root_rv) { + if (!tir::IsRootWithNoAnnotation(sch, root_rv)) { + return {sch}; + } + // Parallelization + if (max_jobs_per_core != -1) { + sch->Annotate(root_rv, tir::attr::meta_schedule_parallel, + Integer(this->max_parallel_extent_)); + } + // Vectorization + if (max_vectorize_extent != -1) { + sch->Annotate(root_rv, tir::attr::meta_schedule_vectorize, Integer(max_vectorize_extent)); + } + // Unroll + if (!unroll_max_steps.empty()) { + int n = unroll_max_steps.size(); + double prob = 1.0 / n; + Array probs(n, FloatImm(DataType::Float(64), prob)); + PrimExpr max_step = sch->SampleCategorical(unroll_max_steps, probs); + if (unroll_explicit) { + sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_explicit, max_step); + } else { + sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_implicit, max_step); + } + } + return {sch}; + } + + public: + /*! + * \brief The maximum number of jobs to be launched per CPU core. + * It sets the uplimit of CPU parallelism, i.e. `num_cores * max_jobs_per_core`. + * Use -1 to disable parallelism. + */ + int64_t max_jobs_per_core; + /*! + * \brief The maximum extent to be vectorized. It sets the uplimit of the CPU vectorization. + * Use -1 to disable vectorization. + */ + int max_vectorize_extent; + /*! + * \brief brief description The maximum number of unroll steps to be done. + * Use an empty array to disable unroll. + */ + Array unroll_max_steps; + /*! \brief Whether to explicitly unroll the loop, or just add a unroll pragma. */ + bool unroll_explicit; + /*! \brief The number of cores in CPU. */ + int64_t max_parallel_extent_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("max_jobs_per_core", &max_jobs_per_core); + v->Visit("max_vectorize_extent", &max_vectorize_extent); + v->Visit("unroll_max_steps", &unroll_max_steps); + v->Visit("unroll_explicit", &unroll_explicit); + // `max_parallel_extent_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.ParallelizeVectorizeUnroll"; + TVM_DECLARE_FINAL_OBJECT_INFO(ParallelizeVectorizeUnrollNode, ScheduleRuleNode); +}; + +ScheduleRule ScheduleRule::ParallelizeVectorizeUnroll(int max_jobs_per_core, + int max_vectorize_extent, + Array unroll_max_steps, + bool unroll_explicit) { + ObjectPtr n = make_object(); + n->max_jobs_per_core = max_jobs_per_core; + n->max_vectorize_extent = max_vectorize_extent; + n->unroll_max_steps = unroll_max_steps; + n->unroll_explicit = unroll_explicit; + n->max_parallel_extent_ = -1; + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(ParallelizeVectorizeUnrollNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleParallelizeVectorizeUnroll") + .set_body_typed(ScheduleRule::ParallelizeVectorizeUnroll); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc new file mode 100644 index 000000000000..ba1476719491 --- /dev/null +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class RandomComputeLocationNode : public ScheduleRuleNode { + public: + bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const { + const tir::StmtSRef& block_sref = sch->GetSRef(block_rv); + const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + + // Cond 1. The block is not the root block. + if (block_sref->parent == nullptr) { + return false; + } + // Cond 2. The block should be the direct child block of the root block. + if (GetScopeRoot(sch->state(), block_sref, // + /*require_stage_pipeline=*/false, // + /*require_subtree_compact_dataflow=*/false) + ->parent != nullptr) { + return false; + } + // Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child + // block. + Array loop_srefs = tir::GetLoops(block_sref); + if (loop_srefs.empty()) { + return false; + } + if (tir::GetChildBlockSRefOnSRefTree(sch->state(), loop_srefs[0]).size() > 1) { + return false; + } + // Cond 5. The block is not tiled. + if (tir::HasBeenMultiLevelTiled(block_sref)) { + return false; + } + // Cond 6. The block has at lease one consumer. + if (tir::GetConsumers(sch->state(), sch->GetSRef(block_rv)).empty()) { + return false; + } + + return true; + } + + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final {} + + // Inherited from ScheduleRuleNode + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + if (!CheckConditions(sch, block_rv)) { + return {sch}; + } + + // Step 1. If the producer of the input block needs a random compute-at location (specified by + // the annotation), we colect the producer first, and transform the producer block later. + // - The reason we collect the producer before transforming the input block is that, if the + // decision of Sample-Compute-Location is "compute-inline" for the input block, we can no longer + // access the input block. Hence we collect its producer ahead of time. + // - Note that only single producer is allowed in this case. + Array producers{nullptr}; + if (tir::HasAnn(sch->GetSRef(block_rv), tir::attr::meta_schedule_random_compute_producer, + true)) { + producers = sch->GetProducers(block_rv); + sch->Unannotate(block_rv, tir::attr::meta_schedule_random_compute_producer); + ICHECK_EQ(producers.size(), 1); + } + + // Step 2. Transform the input block. + tir::Schedule res = RandomlyComputeAt(sch, block_rv); + + // Step 3. Transform the producer block if compute-location sampling is needed. + if (producers.defined()) { + res = RandomlyComputeAt(res, producers[0]); + } + + return {res}; + } + + /*! + * \brief Keep sampling a compute-at location for the input block until success. + * \param sch The TIR schedule + * \param block_rv The block whose compute-at location is to be sampled + * \return The TIR schedule after transformation + */ + tir::Schedule RandomlyComputeAt(const tir::Schedule& sch, const tir::BlockRV& block_rv) { + for (;;) { + tir::LoopRV compute_at_loc = sch->SampleComputeLocation(block_rv); + try { + sch->ComputeAt(block_rv, compute_at_loc, true); + } catch (const dmlc::Error& e) { + // ComputeAt fails, cleanup the following before re-try: + // 1) trace: instruction & decisions + // 2) sym_tab + sch->trace().value()->Pop(); + sch->RemoveRV(compute_at_loc); + continue; + } + break; + } + return sch; + } + + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "meta_schedule.RandomComputeLocation"; + TVM_DECLARE_FINAL_OBJECT_INFO(RandomComputeLocationNode, ScheduleRuleNode); +}; + +ScheduleRule ScheduleRule::RandomComputeLocation() { + ObjectPtr n = make_object(); + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(RandomComputeLocationNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleRandomComputeLocation") + .set_body_typed(ScheduleRule::RandomComputeLocation); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc new file mode 100644 index 000000000000..4c6c16a1fd10 --- /dev/null +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -0,0 +1,675 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../utils.h" + +#define TVM_META_SCHEDULE_CHECK_PROB_RANGE(p, name) \ + CHECK(0.0 <= (p) && (p) <= 1.0) << "ValueError: name should be within [0, 1], " \ + << "but get `" << #p << " = " << (p) << '\''; + +namespace tvm { +namespace meta_schedule { + +using tir::Schedule; + +/**************** Data Structure ****************/ + +/*! + * \brief A heap with a size up-limit. If overflow happens, it evicted the worst items. + * \note It maintains a min heap in terms of `Item::score`. Therefore, when + * overflow happens, the element evicted is the one with the min `Item::score`. + * As time goes, the elements in the heap are going to be larger. + */ +class SizedHeap { + public: + struct Item { + Schedule sch; + IRModule mod; + size_t shash; + double score; + bool operator<(const Item& other) const { return score > other.score; } + }; + + struct ItemHash { + size_t operator()(const Item& hash) const { return hash.shash; } + }; + + struct ItemEqual { + bool operator()(const Item& lhs, const Item& rhs) const { + return lhs.shash == rhs.shash && StructuralEqual()(lhs.mod, rhs.mod); + } + }; + /*! + * \brief Constructor + * \param size_limit The up-limit of the heap size + */ + explicit SizedHeap(int size_limit) : size_limit(size_limit) { heap.reserve(size_limit); } + + /*! + * \brief Push the specific item to the heap if its key did not appears in the heap + * \param item The item to be pushed + */ + void Push(Schedule sch, IRModule mod, double score) { + Item item{sch, mod, StructuralHash()(mod), score}; + if (!in_heap.insert(item).second) { + return; + } + int size = heap.size(); + if (size < size_limit) { + // Heap is not full, just push + heap.emplace_back(item); + std::push_heap(heap.begin(), heap.end()); + } else if (item.score > heap.front().score) { + // if the item is better than the worst one in the heap, we can safely kick it out + std::pop_heap(heap.begin(), heap.end()); + heap.back() = item; + std::push_heap(heap.begin(), heap.end()); + } + // Otherwise, the item is worse than any other element in the heap + } + + /*! \brief Up-limit of the heap size */ + int size_limit; + /*! \brief The heap, the worse the topper */ + std::vector heap; + /*! \brief The traces that are in the heap */ + std::unordered_set in_heap; +}; + +struct PerThreadData { + IRModule mod{nullptr}; + TRandState rand_state{-1}; + std::function trace_sampler = nullptr; + std::function()> mutator_sampler = nullptr; + + /*! + * \brief Set the value for the trace and mutator samplers per thread. + * \param scores The predicted score for the given samples. + * \param genetic_mutate_prob The probability of mutation. + * \param mutator_probs The probability of each mutator as a dict. + */ + void Set(const std::vector& scores, double genetic_mutate_prob, + const Map& mutator_probs) { + trace_sampler = tir::MakeMultinomialSampler(&rand_state, scores); + mutator_sampler = MakeMutatorSampler(genetic_mutate_prob, mutator_probs, &rand_state); + } + + private: + /*! + * \brief Create a sampler function that picks mutators according to the mass function + * \param rand_state The random state for sampling + * \return The sampler created + */ + static std::function()> MakeMutatorSampler( + double genetic_mutate_prob, // + const Map& mutator_probs, // + TRandState* rand_state) { + std::vector> mutators; + std::vector masses; + mutators.push_back(NullOpt); + masses.push_back(1.0 - genetic_mutate_prob); + double total_mass_mutator = 0.0; + if (genetic_mutate_prob > 0) { + for (const auto& kv : mutator_probs) { + Mutator mutator = kv.first; + double mass = kv.second->value; + total_mass_mutator += mass; + mutators.push_back(mutator); + masses.push_back(mass * genetic_mutate_prob); + } + } + // Normalize the sum to 1.0 + if (total_mass_mutator == 0.0) { + masses[0] = 1.0; + for (int i = 1, n = masses.size(); i < n; ++i) { + masses[i] = 0.0; + } + } else if (total_mass_mutator != 1.0) { + for (int i = 1, n = masses.size(); i < n; ++i) { + masses[i] /= total_mass_mutator; + } + } + return [idx_sampler = tir::MakeMultinomialSampler(rand_state, masses), + mutators = std::move(mutators)]() -> Optional { + int i = idx_sampler(); + return mutators[i]; + }; + } +}; + +struct ConcurrentBitmask { + /*! The bit width. */ + static constexpr const int kBitWidth = 64; + /*! \brief The size of the concurrent bitmask. */ + int size; + /*! \brief The bitmasks. */ + std::vector bitmask; + /*! \brief The mutexes, one per kBitWidth(64 here) bitmasks. */ + std::vector mutexes; + + /*! + * \brief Constructor + * \param n The total slots managed by the concurrent bitmask. + */ + explicit ConcurrentBitmask(int n) + : size((n + kBitWidth - 1) / kBitWidth), bitmask(size, 0), mutexes(size) {} + /*! + * \brief Query and mark the given index if not visited before. + * \param x The index to concurrently check if used. If not, mark as used. + * \return Whether the index has been used before. + */ + bool QueryAndMark(int x) { + constexpr uint64_t one = 1; + std::unique_lock lock(mutexes[x / kBitWidth]); + if (bitmask[x / kBitWidth] & (one << (x % kBitWidth))) { + return false; + } else { + bitmask[x / kBitWidth] |= one << (x % kBitWidth); + return true; + } + } +}; + +/**************** Util Functions ****************/ + +/*! + * \brief Assemble measure candidates from the given candidate traces. + * \param traces The picked candidate traces. + * \return The assembled measure candidates. + */ +Array AssembleCandidates(const std::vector& picks, + const Array& args_info) { + Array measure_inputs; + measure_inputs.reserve(picks.size()); + for (const Schedule& sch : picks) { + measure_inputs.push_back(MeasureCandidate(sch, args_info)); + } + return measure_inputs; +} + +/*! + * \brief Predict the normalized score of each candidate. + * \param candidates The candidates for prediction + * \param task The search task + * \param space The search space + * \return The normalized score in the prediction + */ +std::vector PredictNormalizedScore(const std::vector& candidates, + const TuneContext& tune_context, + const CostModel& cost_model, + const Array& args_info) { + ICHECK(!candidates.empty()) << "Candidates given for score prediction can not be empty list!"; + std::vector scores = + cost_model->Predict(tune_context, AssembleCandidates(candidates, args_info)); + for (double& score : scores) { + score = std::max(0.0, score); + } + return scores; +} + +/**************** Evolutionary Search ****************/ + +/*!\brief A search strategy that generates measure candidates using evolutionary search. */ +class EvolutionarySearchNode : public SearchStrategyNode { + public: + /*! \brief The state of the search strategy. */ + struct State { + /*! \brief The search strategy itself */ + EvolutionarySearchNode* self; + /*! \brief The design spaces. Decisions are not used so traces only. */ + Array design_spaces; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int st; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int ed; + + explicit State(EvolutionarySearchNode* self, Array design_spaces) + : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {} + + /*! + * \brief Pick up best candidates from database. + * \param num The number of traces to produce. + * \return The picked best candidates. + */ + inline std::vector PickBestFromDatabase(int num); + /*! + * \brief Sample the initial population from previous measured results and randomly generated + * traces via trace replaying. + * \param num The number of traces to produce. + * \return The initial population of traces sampled. + */ + inline std::vector SampleInitPopulation(int num); + /*! + * \brief Evolve the initial population using mutators and samplers. + * \param population The initial population of traces sampled. + * \param num The number of traces to produce. + * \return The evolved traces from initial population. + */ + inline std::vector EvolveWithCostModel(std::vector population, int num); + /*! + * \brief Pick final candidates from the given initial population and bests of evolved ones. + * \param inits The initial population of traces sampled. + * \param bests The best candidates predicted from evolved traces. + * \param num The number of traces to produce. + * \return The final picked candidates with a ratio of both. + */ + inline std::vector PickWithEpsGreedy(const std::vector& inits, + const std::vector& bests, int num); + /*! \brief An interface method to be called by it's counterpart in EvolutionarySearchNode */ + inline Optional> GenerateMeasureCandidates(); + /*! \brief An interface method to be called by it's counterpart in EvolutionarySearchNode */ + inline void NotifyRunnerResults(const TuneContext& tune_context, + const Array& measure_candidates, + const Array& results); + }; + + /*! \brief The tuning context of the evolutionary search strategy. */ + const TuneContextNode* tune_context_{nullptr}; + /*! \brief The target for the workload. */ + Target target_{nullptr}; + /*! \brief The metadata of the function arguments. */ + Array args_info_{nullptr}; + /*! \brief A Database for selecting useful candidates. */ + Database database_{nullptr}; + /*! \brief A cost model helping to explore the search space */ + CostModel cost_model_{nullptr}; + /*! \brief The postprocessors. */ + Array postprocs_{nullptr}; + /*! \brief Mutators and their probability mass */ + Map mutator_probs_{nullptr}; + /*! \brief The number of threads to use. To be initialized with TuneContext. */ + int num_threads_; + /*! \brief The random state. To be initialized with TuneContext. */ + TRandState rand_state_; + /*! \brief Pre thread data including module to be tuned and random state. */ + std::vector per_thread_data_; + /*! \brief The state of the search strategy. */ + std::unique_ptr state_ = nullptr; + /*! \brief The token registered for the given workload in database. */ + Workload token_{nullptr}; + + /*** Configuration: global ***/ + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; + /*! \brief The number of total trials. */ + int num_trials_total; + /*! \brief The population size in the evolutionary search. */ + int population_size; + /*** Configuration: the initial population ***/ + /*! \brief The ratio of measured states used in the initial population */ + double init_measured_ratio; + /*! \brief The minimal size of unmeasured population in the initial sampling.*/ + int init_min_unmeasured; + /*** Configuration: evolution ***/ + /*! \brief The number of iterations performed by generic algorithm. */ + int genetic_num_iters; + /*! \brief The probability to perform mutation */ + double genetic_mutate_prob; + /*! \brief The maximum number to try evolving the given trace. */ + int genetic_max_fail_count; + /*** Configuration: pick states for measurement ***/ + /*! \brief The ratio of measurements to use randomly sampled states. */ + double eps_greedy; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `tune_context_` is not visited + // `target_` is not visited + // `args_info_` is not visited + // `database` is not visited + // `cost_model` is not visited + // `postprocs` is not visited + // `mutator_probs_` is not visited + // `num_threads` is not visited + // `rand_state_` is not visited + // `per_thread_data_` is not visited + // `state_` is not visited + + /*** Configuration: global ***/ + v->Visit("num_trials_total", &num_trials_total); + v->Visit("num_trials_per_iter", &num_trials_per_iter); + v->Visit("population_size", &population_size); + /*** Configuration: the initial population ***/ + v->Visit("init_measured_ratio", &init_measured_ratio); + v->Visit("init_min_unmeasured", &init_min_unmeasured); + /*** Configuration: evolution ***/ + v->Visit("genetic_num_iters", &genetic_num_iters); + v->Visit("genetic_mutate_prob", &genetic_mutate_prob); + v->Visit("genetic_max_fail_count", &genetic_max_fail_count); + /*** Configuration: pick states for measurement ***/ + v->Visit("eps_greedy", &eps_greedy); + } + + static constexpr const char* _type_key = "meta_schedule.EvolutionarySearch"; + TVM_DECLARE_FINAL_OBJECT_INFO(EvolutionarySearchNode, SearchStrategyNode); + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + CHECK(tune_context.defined()) << "TuneContext must be defined!"; + CHECK(tune_context->num_threads > 0) << "Number of threads has to be larger than 0."; + CHECK(tune_context->target.defined()) << "Target must be defined!"; + this->tune_context_ = tune_context.get(); + this->target_ = tune_context->target.value(); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value())); + this->mutator_probs_ = tune_context->mutator_probs; + this->postprocs_ = tune_context->postprocs; + this->num_threads_ = tune_context->num_threads; + this->rand_state_ = ForkSeed(&tune_context->rand_state); + this->cost_model_ = tune_context->task_scheduler->cost_model.value(); + this->database_ = tune_context->task_scheduler->database; + this->token_ = this->database_->CommitWorkload(tune_context->mod.value()); + this->per_thread_data_.resize(this->num_threads_); + for (const auto& kv : this->mutator_probs_) { + double mass = kv.second->value; + TVM_META_SCHEDULE_CHECK_PROB_RANGE(mass, "mutator_probs"); + } + for (PerThreadData& data : this->per_thread_data_) { + data.mod = DeepCopyIRModule(tune_context->mod.value()); + data.rand_state = ForkSeed(&this->rand_state_); + } + this->state_.reset(); + } + + void PreTuning(const Array& design_spaces) final { + ICHECK(!design_spaces.empty()); + ICHECK(this->state_ == nullptr); + // Change to traces + Array design_space_traces; + design_space_traces.reserve(design_spaces.size()); + for (const Schedule& space : design_spaces) { + design_space_traces.push_back(space->trace().value()->Simplified(true)); + } + this->state_ = std::make_unique(this, design_space_traces); + } + + void PostTuning() final { + ICHECK(this->state_ != nullptr); + this->state_.reset(); + } + + Optional> GenerateMeasureCandidates() final { + ICHECK(this->state_ != nullptr); + return this->state_->GenerateMeasureCandidates(); + } + + void NotifyRunnerResults(const TuneContext& tune_context, + const Array& measure_candidates, + const Array& results) final { + ICHECK(this->state_ != nullptr); + this->state_->NotifyRunnerResults(tune_context, measure_candidates, results); + } +}; + +std::vector EvolutionarySearchNode::State::PickBestFromDatabase(int num) { + std::vector measured_traces; + measured_traces.reserve(num); + Array top_records = self->database_->GetTopK(self->token_, num); + for (TuningRecord record : top_records) { + measured_traces.push_back(record->trace); + } + int actual_num = measured_traces.size(); + ThreadedTraceApply pp(self->postprocs_); + std::vector results(actual_num, Schedule{nullptr}); + auto f_proc_measured = [this, &measured_traces, &results, &pp](int thread_id, + int trace_id) -> void { + PerThreadData& data = self->per_thread_data_.at(thread_id); + TRandState* rand_state = &data.rand_state; + const IRModule& mod = data.mod; + tir::Trace trace = measured_traces.at(trace_id); + Schedule& result = results.at(trace_id); + ICHECK(!result.defined()); + if (Optional sch = pp.Apply(mod, trace, rand_state)) { + result = sch.value(); + } else { + LOG(FATAL) << "ValueError: Cannot postprocess the trace:\n" << trace; + throw; + } + }; + support::parallel_for_dynamic(0, actual_num, self->num_threads_, f_proc_measured); + return results; +} + +std::vector EvolutionarySearchNode::State::SampleInitPopulation(int num) { + ThreadedTraceApply pp(self->postprocs_); + std::vector out_schs; + while (static_cast(out_schs.size()) < self->init_min_unmeasured) { + std::vector results(num, Schedule{nullptr}); + auto f_proc_unmeasured = [this, &results, &pp](int thread_id, int trace_id) -> void { + PerThreadData& data = self->per_thread_data_.at(thread_id); + TRandState* rand_state = &data.rand_state; + const IRModule& mod = data.mod; + Schedule& result = results.at(trace_id); + ICHECK(!result.defined()); + int design_space_index = tir::SampleInt(rand_state, 0, design_spaces.size()); + tir::Trace trace(design_spaces[design_space_index]->insts, {}); + if (Optional sch = pp.Apply(mod, trace, rand_state)) { + result = sch.value(); + } + }; + support::parallel_for_dynamic(0, num, self->num_threads_, f_proc_unmeasured); + for (int i = 0; i < num; i++) { + if (results[i].defined()) { + out_schs.push_back(results[i]); + } + } + LOG(INFO) << "Sample-Init-Population summary:\n" << pp.SummarizeFailures(); + } + return out_schs; +} + +std::vector EvolutionarySearchNode::State::EvolveWithCostModel( + std::vector population, int num) { + ICHECK_GT(num, 0); + // The heap to record best schedule, we do not consider schedules that are already measured + // Also we use `in_heap` to make sure items in the heap are de-duplicated + SizedHeap heap(num); + for (int iter = 0;; ++iter) { + // Predict normalized score with the cost model, + std::vector scores = + PredictNormalizedScore(population, // + GetRef(self->tune_context_), // + self->cost_model_, // + self->args_info_); + ICHECK_EQ(scores.size(), population.size()); + for (int i = 0, n = population.size(); i < n; ++i) { + Schedule sch = population.at(i); + IRModule mod = sch->mod(); + double score = scores.at(i); + if (!self->database_->HasWorkload(mod)) { + heap.Push(sch, mod, score); + } + } + // Discontinue once it reaches end of search + if (iter == self->genetic_num_iters) { + break; + } + // Set threaded samplers, with probability from predicated normalized throughputs + for (PerThreadData& data : self->per_thread_data_) { + data.Set(scores, self->genetic_mutate_prob, self->mutator_probs_); + } + ThreadedTraceApply pp(self->postprocs_); + ConcurrentBitmask cbmask(self->population_size); + std::vector next_population(self->population_size, Schedule{nullptr}); + // The worker function + auto f_find_candidate = [&cbmask, &population, &next_population, &pp, this](int thread_id, + int trace_id) { + // Prepare samplers + PerThreadData& data = self->per_thread_data_.at(thread_id); + TRandState* rand_state = &data.rand_state; + const IRModule& mod = data.mod; + std::function& trace_sampler = data.trace_sampler; + std::function()>& mutator_sampler = data.mutator_sampler; + Schedule& result = next_population.at(trace_id); + int sampled_trace_id = -1; + // Loop until success + for (int fail_count = 0; fail_count <= self->genetic_max_fail_count; ++fail_count) { + sampled_trace_id = trace_sampler(); + tir::Trace trace = population.at(sampled_trace_id)->trace().value(); + if (Optional opt_mutator = mutator_sampler()) { + // Decision: mutate + Mutator mutator = opt_mutator.value(); + if (Optional new_trace = mutator->Apply(trace, rand_state)) { + if (Optional sch = pp.Apply(mod, new_trace.value(), rand_state)) { + // note that sch's trace is different from new_trace + // because it contains post-processing information + result = sch.value(); + break; + } + } + } else if (cbmask.QueryAndMark(sampled_trace_id)) { + // Decision: do not mutate + break; + } + } + // if retry count exceeds the limit, reuse an old sample + if (!result.defined()) { + result = population.at(sampled_trace_id); + } + }; + support::parallel_for_dynamic(0, self->population_size, self->num_threads_, f_find_candidate); + population.swap(next_population); + LOG(INFO) << "Evolve iter #" << iter << " done. Summary:\n" << pp.SummarizeFailures(); + } + // Return the best states from the heap, sorting from higher score to lower ones + std::sort(heap.heap.begin(), heap.heap.end()); + std::vector results; + results.reserve(num); + for (const SizedHeap::Item& item : heap.heap) { + results.push_back(item.sch); + } + + constexpr int kNumScoresPerLine = 16; + std::ostringstream os; + int n = heap.heap.size(); + for (int st = 0; st < n; st += kNumScoresPerLine) { + os << std::endl; + int ed = std::min(st + kNumScoresPerLine, n); + os << "[" << (st + 1) << " : " << ed << "]:\t"; + for (int i = st; i < ed; ++i) { + if (i != st) { + os << " "; + } + os << std::fixed << std::setprecision(4) << heap.heap.at(i).score; + } + } + LOG(INFO) << "Scores of the best " << n << " candidates:" << os.str(); + return results; +} + +std::vector EvolutionarySearchNode::State::PickWithEpsGreedy( + const std::vector& unmeasured, const std::vector& bests, int num) { + int num_rands = num * self->eps_greedy; + int num_bests = num - num_rands; + std::vector rands = + tir::SampleWithoutReplacement(&self->rand_state_, unmeasured.size(), unmeasured.size()); + std::vector results; + results.reserve(num); + for (int i = 0, i_bests = 0, i_rands = 0; i < num; ++i) { + bool has_best = i_bests < static_cast(bests.size()); + bool has_rand = i_rands < static_cast(rands.size()); + // Pick a schedule + Schedule sch{nullptr}; + // If needs `bests`, then prefer `bests` + if (i < num_bests) { + if (has_best) { + sch = bests[i_bests++]; + } else if (has_rand) { + sch = unmeasured[rands[i_rands++]]; + } else { + break; + } + } else { + // Else prefer `rands` + if (has_rand) { + sch = unmeasured[rands[i_rands++]]; + } else if (has_best) { + sch = bests[i_bests++]; + } else { + break; + } + } + results.push_back(sch); + } + return results; +} + +Optional> EvolutionarySearchNode::State::GenerateMeasureCandidates() { + if (st >= self->num_trials_total) { + return NullOpt; + } + int sample_num = self->num_trials_per_iter; + if (ed > self->num_trials_total) { + sample_num = self->num_trials_total - st; + ed = self->num_trials_total; + } + ICHECK_LT(st, ed); + int pop = self->population_size; + std::vector inits; + inits.reserve(pop); + + LOG(INFO) << "Generating candidates......"; + std::vector measured = PickBestFromDatabase(pop * self->init_measured_ratio); + LOG(INFO) << "Picked top " << measured.size() << " candidate(s) from database"; + std::vector unmeasured = SampleInitPopulation(pop - measured.size()); + LOG(INFO) << "Sampled " << unmeasured.size() << " candidate(s)"; + inits.insert(inits.end(), measured.begin(), measured.end()); + inits.insert(inits.end(), unmeasured.begin(), unmeasured.end()); + std::vector bests = EvolveWithCostModel(inits, sample_num); + LOG(INFO) << "Got " << bests.size() << " candidate(s) with evolutionary search"; + std::vector picks = PickWithEpsGreedy(unmeasured, bests, sample_num); + LOG(INFO) << "Sending " << picks.size() << " candidates(s) for measurement"; + return AssembleCandidates(picks, self->args_info_); +} + +void EvolutionarySearchNode::State::NotifyRunnerResults( + const TuneContext& tune_context, const Array& measure_candidates, + const Array& results) { + st += results.size(); + ed += results.size(); +} + +SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, // + int num_trials_total, // + int population_size, // + double init_measured_ratio, // + int init_min_unmeasured, // + int genetic_num_iters, // + double genetic_mutate_prob, // + int genetic_max_fail_count, // + double eps_greedy) { + TVM_META_SCHEDULE_CHECK_PROB_RANGE(init_measured_ratio, "Initial measured ratio"); + TVM_META_SCHEDULE_CHECK_PROB_RANGE(genetic_mutate_prob, "Mutation probability"); + TVM_META_SCHEDULE_CHECK_PROB_RANGE(eps_greedy, "Greedy pick probability"); + ObjectPtr n = make_object(); + n->num_trials_per_iter = num_trials_per_iter; + n->num_trials_total = num_trials_total; + n->population_size = population_size; + n->init_measured_ratio = init_measured_ratio; + n->init_min_unmeasured = init_min_unmeasured; + n->genetic_num_iters = genetic_num_iters; + n->genetic_max_fail_count = genetic_max_fail_count; + n->genetic_mutate_prob = genetic_mutate_prob; + n->eps_greedy = eps_greedy; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch") + .set_body_typed(SearchStrategy::EvolutionarySearch); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc new file mode 100644 index 000000000000..e6684d507f2e --- /dev/null +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief A search strategy that generates measure candidates using space generator. */ +class ReplayFuncNode : public SearchStrategyNode { + public: + /*! \brief The state of the search strategy. */ + struct State { + /*! \brief The search strategy itself */ + ReplayFuncNode* self; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int st; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int ed; + + explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {} + + inline Optional> GenerateMeasureCandidates(); + inline void NotifyRunnerResults(const Array& results); + }; + + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; + /*! \brief The number of total trials. */ + int num_trials_total; + + /*! \brief The module to be tuned. */ + IRModule mod_{nullptr}; + /*! \brief The metadata of the function arguments. */ + Array args_info_{nullptr}; + /*! \brief The post processors */ + Array postprocs_{nullptr}; + /*! \brief The space generator for measure candidates generation. */ + SpaceGenerator space_generator_{nullptr}; + /*! \brief The random state. -1 means using random number. */ + TRandState rand_state_ = -1; + /*! \brief The state of the search strategy. */ + std::unique_ptr state_ = nullptr; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("num_trials_per_iter", &num_trials_per_iter); + v->Visit("num_trials_total", &num_trials_total); + // `space_generator_` is not visited + // `mod_` is not visited + // `args_info_` is not visited + // `num_threads_` is not visited + // `rand_state_` is not visited + // `state_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.ReplayFunc"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode); + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + this->space_generator_ = tune_context->space_generator.value(); + this->mod_ = tune_context->mod.value(); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value())); + this->postprocs_ = tune_context->postprocs; + this->rand_state_ = ForkSeed(&tune_context->rand_state); + this->state_.reset(); + } + + void PreTuning(const Array& design_spaces) final { + ICHECK(this->state_ == nullptr); + this->state_ = std::make_unique(this); + } + + void PostTuning() final { + ICHECK(this->state_ != nullptr); + this->state_.reset(); + } + + Optional> GenerateMeasureCandidates() final { + ICHECK(this->state_ != nullptr); + return this->state_->GenerateMeasureCandidates(); + } + + void NotifyRunnerResults(const TuneContext& tune_context, + const Array& measure_candidates, + const Array& results) final { + ICHECK(this->state_ != nullptr); + this->state_->NotifyRunnerResults(results); + } +}; + +inline Optional> ReplayFuncNode::State::GenerateMeasureCandidates() { + if (st >= self->num_trials_total) { + return NullOpt; + } + ed = std::min(ed, self->num_trials_total); + Array result; + for (int i = st; i < ed; i++) { + for (;;) { + Array schs = self->space_generator_->GenerateDesignSpace(self->mod_); + int design_space_index = tir::SampleInt(&self->rand_state_, 0, schs.size()); + tir::Schedule sch = schs[design_space_index]; + sch->EnterPostproc(); + bool failed = false; + for (const Postproc& proc : self->postprocs_) { + if (!proc->Apply(sch)) { + failed = true; + break; + } + } + if (!failed) { + result.push_back(MeasureCandidate(sch, self->args_info_)); + break; + } + } + } + return result; +} + +inline void ReplayFuncNode::State::NotifyRunnerResults(const Array& results) { + st += self->num_trials_per_iter; + ed += self->num_trials_per_iter; +} + +SearchStrategy SearchStrategy::ReplayFunc(int num_trials_per_iter, int num_trials_total) { + ObjectPtr n = make_object(); + n->num_trials_per_iter = num_trials_per_iter; + n->num_trials_total = num_trials_total; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(ReplayFuncNode); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc") + .set_body_typed(SearchStrategy::ReplayFunc); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 200eca34133d..2984f36f4514 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include "tvm/tir/schedule/schedule.h" namespace tvm { namespace meta_schedule { @@ -24,20 +25,18 @@ namespace meta_schedule { /*! \brief A search strategy that generates measure candidates using trace and random decisions. */ class ReplayTraceNode : public SearchStrategyNode { public: - using TRandState = support::LinearCongruentialEngine::TRandState; - /*! \brief The state of the search strategy. */ struct State { /*! \brief The search strategy itself */ ReplayTraceNode* self; /*! \brief The design spaces. */ - Array design_spaces; + Array design_spaces; /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ int st; /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ int ed; - explicit State(ReplayTraceNode* self, Array design_spaces) + explicit State(ReplayTraceNode* self, Array design_spaces) : self(self), design_spaces(design_spaces), st(0), ed(self->num_trials_per_iter) {} inline Optional> GenerateMeasureCandidates(); @@ -50,9 +49,11 @@ class ReplayTraceNode : public SearchStrategyNode { int num_trials_total; /*! \brief The module to be tuned. */ - IRModule mod_{nullptr}; + Array per_thread_mod_{nullptr}; /*! \brief The metadata of the function arguments. */ Array args_info_{nullptr}; + /*! \brief The post processors */ + Array postprocs_{nullptr}; /*! \brief The number of threads to use. -1 means using logical cpu number. */ int num_threads_ = -1; /*! \brief The random state. -1 means using random number. */ @@ -63,8 +64,9 @@ class ReplayTraceNode : public SearchStrategyNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("num_trials_per_iter", &num_trials_per_iter); v->Visit("num_trials_total", &num_trials_total); - // `mod_` is not visited + // `per_thread_mod_` is not visited // `args_info_` is not visited + // `postprocs_` is not visited // `num_threads_` is not visited // `rand_state_` is not visited // `state_` is not visited @@ -73,18 +75,30 @@ class ReplayTraceNode : public SearchStrategyNode { static constexpr const char* _type_key = "meta_schedule.ReplayTrace"; TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode); - void InitializeWithTuneContext(const TuneContext& context) final { - this->mod_ = context->mod.value(); - this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(this->mod_)); - this->num_threads_ = context->num_threads; - this->rand_state_ = ForkSeed(&context->rand_state); + void InitializeWithTuneContext(const TuneContext& tune_context) final { + CHECK(tune_context->num_threads > 0) << "Number of threads has to be larger than 0."; + this->num_threads_ = tune_context->num_threads; + + this->per_thread_mod_.reserve(this->num_threads_); + for (int i = 0; i < this->num_threads_; i++) { + this->per_thread_mod_.push_back(DeepCopyIRModule(tune_context->mod.value())); + } + + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value())); + this->postprocs_ = tune_context->postprocs; + this->rand_state_ = ForkSeed(&tune_context->rand_state); this->state_.reset(); } void PreTuning(const Array& design_spaces) final { ICHECK(!design_spaces.empty()); ICHECK(this->state_ == nullptr); - this->state_ = std::make_unique(this, design_spaces); + Array design_space_traces; + design_space_traces.reserve(design_spaces.size()); + for (const tir::Schedule& space : design_spaces) { + design_space_traces.push_back(space->trace().value()->Simplified(true)); + } + this->state_ = std::make_unique(this, design_space_traces); } void PostTuning() final { @@ -97,7 +111,9 @@ class ReplayTraceNode : public SearchStrategyNode { return this->state_->GenerateMeasureCandidates(); } - void NotifyRunnerResults(const Array& results) final { + void NotifyRunnerResults(const TuneContext& tune_context, + const Array& measure_candidates, + const Array& results) final { ICHECK(this->state_ != nullptr); this->state_->NotifyRunnerResults(results); } @@ -111,19 +127,20 @@ inline Optional> ReplayTraceNode::State::GenerateMeasure ICHECK_LT(st, ed); std::vector per_thread_rand_state = ForkSeed(&self->rand_state_, self->num_threads_); Array per_task_result(ed - st, MeasureCandidate{nullptr}); - auto f_worker = [this, &per_thread_rand_state, &per_task_result](int thread_id, - int task_id) -> void { + ThreadedTraceApply pp(self->postprocs_); + auto f_worker = [this, &per_thread_rand_state, &per_task_result, &pp](int thread_id, + int task_id) -> void { TRandState& rand_state = per_thread_rand_state[thread_id]; - int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); - tir::Trace trace = design_spaces[design_space_index]->trace().value(); - tir::Trace new_trace = tir::Trace(trace->insts, {}); - tir::Schedule sch = tir::Schedule::Traced( // - self->mod_, // - /*rand_state=*/ForkSeed(&rand_state), // - /*debug_mode=*/0, // - /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); - new_trace->ApplyToSchedule(sch, /*remove_postproc=*/true); - per_task_result.Set(task_id, MeasureCandidate(sch, self->args_info_)); + IRModule mod = self->per_thread_mod_[thread_id]; + for (;;) { + int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); + tir::Trace trace = design_spaces[design_space_index]; + tir::Trace new_trace = tir::Trace(trace->insts, {}); + if (Optional sch = pp.Apply(mod, new_trace, &rand_state)) { + per_task_result.Set(task_id, MeasureCandidate(sch.value(), self->args_info_)); + break; + } + } }; support::parallel_for_dynamic(0, ed - st, self->num_threads_, f_worker); return per_task_result; @@ -142,7 +159,8 @@ SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int num_tria } TVM_REGISTER_NODE_TYPE(ReplayTraceNode); -TVM_REGISTER_GLOBAL("meta_schedule.ReplayTrace").set_body_typed(SearchStrategy::ReplayTrace); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayTrace") + .set_body_typed(SearchStrategy::ReplayTrace); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index bc616327eb3b..e9a5f268ec2d 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -86,29 +86,73 @@ class PostOrderApplyNode : public SpaceGeneratorNode { // `sch_rules_` is not visited } - void InitializeWithTuneContext(const TuneContext& context) final { - this->rand_state_ = ForkSeed(&context->rand_state); - CHECK(context->sch_rules.defined()) + void InitializeWithTuneContext(const TuneContext& tune_context) final { + this->rand_state_ = ForkSeed(&tune_context->rand_state); + CHECK(tune_context->sch_rules.defined()) << "ValueError: Schedules rules not given in PostOrderApply!"; - this->sch_rules_ = context->sch_rules; + this->sch_rules_ = tune_context->sch_rules; } Array GenerateDesignSpace(const IRModule& mod_) final { using ScheduleAndUnvisitedBlocks = std::pair>; - tir::Schedule sch = tir::Schedule::Traced( // - /*mod=*/mod_, // - /*rand_state=*/ForkSeed(&this->rand_state_), // - /*debug_mode=*/tir::kVerifySRefTree | tir::kVerifyCachedFlags, // + tir::Schedule sch = tir::Schedule::Traced( // + /*mod=*/mod_, // + /*rand_state=*/ForkSeed(&this->rand_state_), // + /*debug_mode=*/0, // tir::kVerifySRefTree | tir::kVerifyCachedFlags /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); std::vector stack; - Array result{sch}; + Array result; + Array all_blocks = BlockCollector::Collect(sch), func_blocks, non_func_blocks; + for (const tir::BlockRV& block_rv : all_blocks) { + if (Optional custom_sch_rule_name_opt = + tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule")) { + if (custom_sch_rule_name_opt.value() != "None") { + func_blocks.push_back(block_rv); + } + } else { + non_func_blocks.push_back(block_rv); + } + } + + // only do this once for schedule rules on block annotations + stack.emplace_back(sch, func_blocks); + while (!stack.empty()) { + // get the stack.top() + tir::Schedule sch; + Array blocks; + std::tie(sch, blocks) = stack.back(); + stack.pop_back(); + // if all blocks are visited + if (blocks.empty()) { + result.push_back(sch); + continue; + } + // otherwise, get the last block that is not visited + tir::BlockRV block_rv = blocks.back(); + blocks.pop_back(); + if (sch->HasBlock(block_rv)) { + // pick out the blocks with annotation for customized search space + Optional custom_sch_rule_name_opt = + tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule"); + ICHECK(custom_sch_rule_name_opt.defined() && custom_sch_rule_name_opt.value() != "None"); + String custom_sch_rule_name = custom_sch_rule_name_opt.value(); + const auto* custom_sch_rule_func = runtime::Registry::Get(custom_sch_rule_name); + CHECK(custom_sch_rule_func) << "The given custom schedule function is not defined!"; + Array applied = (*custom_sch_rule_func)(sch, block_rv); + for (const tir::Schedule& sch : applied) { + stack.emplace_back(sch, blocks); + } + } else { + stack.emplace_back(sch, blocks); + } + } + // Enumerate the schedule rules first because you can // always concat multiple schedule rules as one - Array all_blocks = BlockCollector::Collect(sch); for (ScheduleRule sch_rule : sch_rules_) { for (const tir::Schedule& sch : result) { - stack.emplace_back(sch, all_blocks); + stack.emplace_back(sch, non_func_blocks); } result.clear(); diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index 3ef5026cae98..72989a20bcd5 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -52,16 +52,23 @@ class RoundRobinNode final : public TaskSchedulerNode { } }; -TaskScheduler TaskScheduler::RoundRobin(Array tasks, // - Builder builder, // - Runner runner, // - Database database) { +TaskScheduler TaskScheduler::RoundRobin(Array tasks, // + Builder builder, // + Runner runner, // + Database database, // + Optional cost_model, // + Optional> measure_callbacks) { ObjectPtr n = make_object(); n->tasks = tasks; n->builder = builder; n->runner = runner; n->database = database; + n->cost_model = cost_model; + n->measure_callbacks = measure_callbacks.value_or({}); n->task_id = -1; + for (const TuneContext& task : tasks) { + task->task_scheduler = n.get(); + } return TaskScheduler(n); } diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 08f2b4f451bd..28f95b2dc0dd 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ - #include "../utils.h" namespace tvm { @@ -29,9 +28,9 @@ namespace meta_schedule { * \param candidates The measure candidates. * \return An array of the builder results. */ -Array SendToBuilder(const Builder& builder, // - const TuneContext& context, +Array SendToBuilder(const Builder& builder, const TuneContext& context, const Array& candidates) { + LOG(INFO) << "Sending " << candidates.size() << " sample(s) to builder"; Target target = context->target.value(); Array inputs; inputs.reserve(candidates.size()); @@ -45,14 +44,14 @@ Array SendToBuilder(const Builder& builder, // * \brief Send the built measure candidates to runner. * \param runner The runner to send the candidates to. * \param context The tuning context. - * \param candidates The mesure candidates. + * \param candidates The measure candidates. * \param builder_results The builder results. * \return An array of the runner results. */ -Array SendToRunner(const Runner& runner, // - const TuneContext& context, +Array SendToRunner(const Runner& runner, const TuneContext& context, const Array& candidates, const Array& builder_results) { + LOG(INFO) << "Sending " << candidates.size() << " sample(s) to runner"; Target target = context->target.value(); ICHECK_EQ(candidates.size(), builder_results.size()); int n = candidates.size(); @@ -94,54 +93,60 @@ Array SendToRunner(const Runner& runner, // void TaskSchedulerNode::InitializeTask(int task_id) { TuneContext task = this->tasks[task_id]; - // Derive the values. - IRModule mod = task->mod.value(); - SpaceGenerator space = task->space_generator.value(); - SearchStrategy strategy = task->search_strategy.value(); - // Initialize Modules. - space->InitializeWithTuneContext(task); - strategy->InitializeWithTuneContext(task); + LOG(INFO) << "Initializing task " << task_id << ": " << task->task_name << ", mod =\n" + << tir::AsTVMScript(task->mod); + this->tasks[task_id]->Initialize(); } void TaskSchedulerNode::Tune() { for (int i = 0; i < static_cast(this->tasks.size()); i++) { + TuneContext task = tasks[i]; // Check Optional value validity. - CHECK(tasks[i]->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined"; - CHECK(tasks[i]->space_generator.defined()) + CHECK(task->mod.defined()) << "ValueError: Require `context.mod`, but it is not defined"; + CHECK(task->space_generator.defined()) << "ValueError: Require `context.space_generator`, but it is not defined"; - CHECK(tasks[i]->search_strategy.defined()) + CHECK(task->search_strategy.defined()) << "ValueError: Require `context.search_strategy`, but it is not defined"; - InitializeTask(i); - - tasks[i]->search_strategy.value()->PreTuning( - tasks[i]->space_generator.value()->GenerateDesignSpace(tasks[i]->mod.value())); + Array design_spaces = + task->space_generator.value()->GenerateDesignSpace(task->mod.value()); + LOG(INFO) << "Total " << design_spaces.size() << " design space(s) generated"; + for (int i = 0, n = design_spaces.size(); i < n; ++i) { + tir::Schedule sch = design_spaces[i]; + tir::Trace trace = sch->trace().value(); + trace = trace->Simplified(true); + LOG(INFO) << "Design space #" << i << ":\n" + << tir::AsTVMScript(sch->mod()) << "\n" + << Concat(trace->AsPython(false), "\n"); + } + task->search_strategy.value()->PreTuning(design_spaces); } int running_tasks = tasks.size(); - while (running_tasks > 0) { - for (int task_id; (task_id = NextTaskId()) != -1;) { - TuneContext task = tasks[task_id]; - ICHECK(!task->is_stopped); - ICHECK(!task->runner_futures.defined()); - SearchStrategy strategy = task->search_strategy.value(); - if ((task->measure_candidates = strategy->GenerateMeasureCandidates()).defined()) { - Array builder_results = - SendToBuilder(this->builder, task, task->measure_candidates.value()); - task->runner_futures = - SendToRunner(this->runner, task, task->measure_candidates.value(), builder_results); - } else { - SetTaskStopped(task_id); - --running_tasks; - } + for (int task_id; (task_id = NextTaskId()) != -1;) { + LOG(INFO) << "Scheduler picks Task #" << task_id + 1 << ": " << tasks[task_id]->task_name; + TuneContext task = tasks[task_id]; + ICHECK(!task->is_stopped); + ICHECK(!task->runner_futures.defined()); + SearchStrategy strategy = task->search_strategy.value(); + if ((task->measure_candidates = strategy->GenerateMeasureCandidates()).defined()) { + Array builder_results = + SendToBuilder(this->builder, task, task->measure_candidates.value()); + task->builder_results = builder_results; + task->runner_futures = + SendToRunner(this->runner, task, task->measure_candidates.value(), builder_results); + } else { + SetTaskStopped(task_id); + --running_tasks; + LOG(INFO) << "Task #" << task_id + 1 << " has finished. Remaining task(s): " << running_tasks; } - int n_tasks = this->tasks.size(); - for (int task_id = 0; task_id < n_tasks; ++task_id) - if (IsTaskRunning(task_id)) { - TuneContext task = tasks[task_id]; - this->JoinRunningTask(task_id); - task->search_strategy.value()->PostTuning(); - } + } + ICHECK_EQ(running_tasks, 0) << "Not all tasks are finished"; + int n_tasks = this->tasks.size(); + for (int task_id = 0; task_id < n_tasks; ++task_id) { + ICHECK(!IsTaskRunning(task_id)) << "Task #" << task_id << " is still running"; + TuneContext task = tasks[task_id]; + task->search_strategy.value()->PostTuning(); } } @@ -175,25 +180,20 @@ void TaskSchedulerNode::JoinRunningTask(int task_id) { for (const RunnerFuture future : task->runner_futures.value()) { results.push_back(future->Result()); } - task->search_strategy.value()->NotifyRunnerResults(results); - task->runner_futures = NullOpt; - // Add to database + task->search_strategy.value()->NotifyRunnerResults(task, task->measure_candidates.value(), + results); + // Invoke the callbacks ICHECK(task->measure_candidates.defined()); - ICHECK(results.size() == task->measure_candidates.value().size()); - int index = 0; - for (const RunnerResult& result : results) { - if (!result->error_msg.defined() && result->run_secs.defined()) { - Optional trace = task->measure_candidates.value()[index]->sch->trace(); - ICHECK(trace.defined()); - this->database->CommitTuningRecord(TuningRecord( - /*trace=*/trace.value(), - /*run_secs=*/result->run_secs.value(), - /*workload=*/this->database->CommitWorkload(task->mod.value()), - /*target=*/task->target.value(), - /*args_info=*/task->measure_candidates.value()[index]->args_info)); - } - index++; + ICHECK(task->builder_results.defined()); + ICHECK_EQ(results.size(), task->measure_candidates.value().size()); + ICHECK_EQ(results.size(), task->builder_results.value().size()); + for (const MeasureCallback& callback : this->measure_callbacks) { + callback->Apply(GetRef(this), task_id, task->measure_candidates.value(), + task->builder_results.value(), results); } + task->measure_candidates = NullOpt; + task->builder_results = NullOpt; + task->runner_futures = NullOpt; } TaskScheduler TaskScheduler::PyTaskScheduler( @@ -201,6 +201,8 @@ TaskScheduler TaskScheduler::PyTaskScheduler( Builder builder, // Runner runner, // Database database, // + Optional cost_model, // + Optional> measure_callbacks, // PyTaskSchedulerNode::FTune f_tune, // PyTaskSchedulerNode::FInitializeTask f_initialize_task, // PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, // @@ -212,6 +214,12 @@ TaskScheduler TaskScheduler::PyTaskScheduler( n->builder = builder; n->runner = runner; n->database = database; + n->cost_model = cost_model; + if (measure_callbacks.defined()) { + n->measure_callbacks = measure_callbacks.value(); + } else { + n->measure_callbacks = {}; + } n->f_tune = f_tune; n->f_initialize_task = f_initialize_task; n->f_set_task_stopped = f_set_task_stopped; diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index ac85d43e7987..c06cb9adc8ff 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ -#include #include #include "./utils.h" @@ -24,21 +23,13 @@ namespace tvm { namespace meta_schedule { -/*! - * \brief Constructor function of TuneContext class. - * \param mod The mod to be optimized. - * \param target The target to be optimized for. - * \param space_generator The design space generator. - * \param task_name The name of the tuning task. - * \param rand_state The random state. - * \param num_threads The number of threads to be used. - * \param verbose The verbosity level. - */ TuneContext::TuneContext(Optional mod, // Optional target, // Optional space_generator, // Optional search_strategy, // Optional> sch_rules, // + Optional> postprocs, // + Optional> mutator_probs, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) { @@ -48,9 +39,11 @@ TuneContext::TuneContext(Optional mod, n->space_generator = space_generator; n->search_strategy = search_strategy; n->sch_rules = sch_rules.value_or({}); - n->task_name = task_name; + n->postprocs = postprocs.value_or({}); + n->mutator_probs = mutator_probs.value_or({}); + n->task_name = task_name.value_or("main"); if (rand_state == -1) { - rand_state = std::random_device()(); + rand_state = support::LinearCongruentialEngine::DeviceRandom(); } support::LinearCongruentialEngine(&n->rand_state).Seed(rand_state); n->num_threads = num_threads; @@ -60,6 +53,26 @@ TuneContext::TuneContext(Optional mod, data_ = std::move(n); } +void TuneContextNode::Initialize() { + if (this->space_generator.defined()) { + this->space_generator.value()->InitializeWithTuneContext(GetRef(this)); + } + if (this->search_strategy.defined()) { + this->search_strategy.value()->InitializeWithTuneContext(GetRef(this)); + } + for (const ScheduleRule& sch_rule : sch_rules) { + sch_rule->InitializeWithTuneContext(GetRef(this)); + } + for (const Postproc& postproc : postprocs) { + postproc->InitializeWithTuneContext(GetRef(this)); + } + if (mutator_probs.defined()) { + for (const auto& kv : mutator_probs) { + kv.first->InitializeWithTuneContext(GetRef(this)); + } + } +} + TVM_REGISTER_NODE_TYPE(TuneContextNode); TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") @@ -68,11 +81,13 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") Optional space_generator, // Optional search_strategy, // Optional> sch_rules, // + Optional> postprocs, // + Optional> mutator_probs, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) -> TuneContext { - return TuneContext(mod, target, space_generator, search_strategy, sch_rules, task_name, - rand_state, num_threads); + return TuneContext(mod, target, space_generator, search_strategy, sch_rules, postprocs, + mutator_probs, task_name, rand_state, num_threads); }); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 0a9ce4a1aed9..862d35e3cf90 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -20,30 +20,33 @@ #define TVM_META_SCHEDULE_UTILS_H_ #include +#include #include #include #include #include #include #include +#include +#include #include #include #include #include #include #include -#include -#include #include -#include +#include +#include #include #include #include "../printer/text_printer.h" #include "../support/array.h" #include "../support/base64.h" -#include "../tir/schedule/primitive.h" +#include "../support/utils.h" +#include "../tir/schedule/utils.h" namespace tvm { namespace meta_schedule { @@ -200,7 +203,7 @@ inline support::LinearCongruentialEngine::TRandState ForkSeed( /*! * \brief Fork a random state into another ones, i.e. PRNG splitting. - * The given random state is also mutated. + * The given random state is also mutated. * \param rand_state The random state to be forked * \param n The number of forks * \return The forked random states @@ -215,6 +218,15 @@ inline std::vector ForkSeed( return results; } +/*! + * \brief Get deep copy of an IRModule. + * \param mod The IRModule to make a deep copy. + * \return The deep copy of the IRModule. + */ +inline IRModule DeepCopyIRModule(IRModule mod) { + return Downcast(LoadJSON(SaveJSON(mod))); +} + /*! * \brief Concatenate strings * \param strs The strings to concatenate @@ -233,6 +245,160 @@ inline std::string Concat(const Array& strs, const std::string& delim) { return os.str(); } +/*! + * \brief Get the BlockRV from a block StmtSRef + * \param sch The schedule + * \param block_sref The block StmtSRef + * \param global_var_name The global variable name + * \return The BlockRV + */ +inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref, + const String& global_var_name) { + const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + return sch->GetBlock(block->name_hint, global_var_name); +} + +/*! + * \brief Get the number of cores in CPU + * \param target The target + * \return The number of cores. + */ +inline int GetTargetNumCores(const Target& target) { + int num_cores = target->GetAttr("num-cores").value_or(-1); + if (num_cores == -1) { + static const auto* f_cpu_count = runtime::Registry::Get("meta_schedule.cpu_count"); + ICHECK(f_cpu_count) + << "ValueError: Cannot find the packed function \"meta_schedule._cpu_count\""; + num_cores = (*f_cpu_count)(false); + LOG(FATAL) + << "Target does not have attribute \"num-cores\", physical core number must be " + "defined! For example, on the local machine, the target must be \"llvm -num-cores " + << num_cores << "\""; + } + return num_cores; +} + +/*! + * \brief A helper data structure that replays a trace and collects failure counts + * for each postprocessor + */ +struct ThreadedTraceApply { + /*! \brief Constructor */ + explicit ThreadedTraceApply(const Array& postprocs) + : n_(postprocs.size()), items_(new Item[n_]) { + for (int i = 0; i < n_; ++i) { + items_[i].postproc = postprocs[i]; + items_[i].fail_counter = 0; + } + } + + /*! \brief Destructor */ + ~ThreadedTraceApply() { delete[] items_; } + + /*! + * \brief Apply the trace and postprocessors to an IRModule + * \param mod The IRModule to be applied + * \param trace The trace to apply to the IRModule + * \param rand_state The random seed + * \return The schedule created, or NullOpt if any postprocessor fails + */ + Optional Apply(const IRModule& mod, const tir::Trace& trace, + TRandState* rand_state) { + tir::Schedule sch = + tir::Schedule::Traced(mod, + /*rand_state=*/ForkSeed(rand_state), + /*debug_mode=*/0, + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); + trace->ApplyToSchedule(sch, /*remove_postproc=*/true); + sch->EnterPostproc(); + for (int i = 0; i < n_; ++i) { + Item& item = items_[i]; + if (!item.postproc->Apply(sch)) { + ++item.fail_counter; + return NullOpt; + } + } + return sch; + } + + /*! \brief Returns a string summarizing the failures on each postprocessor */ + std::string SummarizeFailures() const { + std::ostringstream os; + for (int i = 0; i < n_; ++i) { + const Item& item = items_[i]; + os << "Postproc #" << i << " [" << item.postproc // + << "]: " << item.fail_counter.load() << " failure(s)"; + if (i != n_ - 1) { + os << "\n"; + } + } + return os.str(); + } + + private: + struct Item { + Postproc postproc{nullptr}; + std::atomic fail_counter{0}; + }; + + int n_; + Item* items_; +}; + +/********** Helper Functions for RuleAddRFactor and RuleCrossThreadReduction **********/ + +/*! + * \brief Reorder the reduction loops to innermost positions if needed. + * \param sch The schedule + * \param block_rv The block where to apply the reorder + * \param fused_reduce_loop The fusion-generated loop to return. + * \param num_spatial_loops The number of spatial loops to return. + * \note Before invoking this helper function, make sure that the block has only spatial and + * reduction loop axes. + */ +inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::BlockRV& block_rv, + tir::LoopRV* fused_reduce_loop, + size_t* num_spatial_loops) { + Array loops = sch->GetLoops(block_rv); + Array loop_srefs; + for (const tir::LoopRV& loop_rv : loops) { + loop_srefs.push_back(sch->GetSRef(loop_rv)); + } + + Array new_order; + // Step 1. Add spatial loops. + *num_spatial_loops = 0; + for (size_t i = 0; i < loops.size(); ++i) { + if (GetLoopIterType(loop_srefs[i]) == tir::kDataPar) { + new_order.push_back(loops[i]); + (*num_spatial_loops)++; + } + } + // Step 2. Add reduction loops. + Array reduction_loops; + for (size_t i = 0; i < loops.size(); ++i) { + if (GetLoopIterType(loop_srefs[i]) == tir::kCommReduce) { + new_order.push_back(loops[i]); + reduction_loops.push_back(loops[i]); + } + } + // Step 3. Apply reordering if new_order differs from the original order. + ICHECK_EQ(new_order.size(), loops.size()); + for (size_t i = 0; i < loops.size(); ++i) { + if (!new_order[i].same_as(loops[i])) { + sch->Reorder(new_order); + break; + } + } + // Step 4. Fuse all the reduction loops if there are multiple reduction loops. + CHECK(!reduction_loops.empty()) << "ValueError: There should be at least one reduction loop"; + if (reduction_loops.size() > 1) { + *fused_reduce_loop = sch->Fuse(reduction_loops); + } else { + *fused_reduce_loop = reduction_loops[0]; + } +} + } // namespace meta_schedule } // namespace tvm diff --git a/src/support/array.h b/src/support/array.h index 95b4f58a2e22..218150f9dba0 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -100,6 +100,29 @@ inline Array AsArray(const ShapeTuple& shape) { return result; } +/*! + * \brief Concatenate a list of arrays into a single array + * \tparam T The type of elements in the arrays + * \tparam Iterator The type of the iterator into the list of arrays + * \param begin The begin iterator to the array list + * \param end The end iterator to the array list + * \return The concatenated array + */ +template +inline Array ConcatArrayList(Iterator begin, Iterator end) { + int size = 0; + for (Iterator it = begin; it != end; ++it) { + size += (*it).size(); + } + Array result; + result.reserve(size); + for (Iterator it = begin; it != end; ++it) { + const auto& item = *it; + result.insert(result.end(), item.begin(), item.end()); + } + return result; +} + /********** Implementation details of AsVector **********/ namespace details { diff --git a/src/support/nd_int_set.h b/src/support/nd_int_set.h index ae4a0386d404..46bbd2bceb9a 100644 --- a/src/support/nd_int_set.h +++ b/src/support/nd_int_set.h @@ -144,6 +144,29 @@ inline NDIntSet NDIntSetEval( return ret; } +/*! + * \brief Output the N-dimensional integer set to a stream. + * \param os The output stream. + * \param nd_int_set The N-dimensional integer set to be output. + * \return The output stream. + */ +inline std::ostream& operator<<(std::ostream& os, const NDIntSet& nd_int_set) { + os << '['; + bool is_first = true; + for (const arith::IntSet& int_set : nd_int_set) { + if (is_first) { + is_first = false; + } else { + os << ", "; + } + PrimExpr min = int_set.min(); + PrimExpr max = int_set.max(); + os << min << ":" << max; + } + os << ']'; + return os; +} + } // namespace support } // namespace tvm diff --git a/src/target/tag.cc b/src/target/tag.cc index a931a288924e..39f8f37aff2b 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -70,6 +70,30 @@ Target TargetTag::AddTag(String name, Map config, bool overri /********** Register Target tags **********/ +TVM_REGISTER_TARGET_TAG("raspberry-pi/4b-64") + .set_config({{"kind", String("llvm")}, + {"mtriple", String("aarch64-linux-gnu")}, + {"mcpu", String("cortex-a72")}, + {"mattr", Array{"+neon"}}, + {"num-cores", Integer(4)}, + {"host", Map{{"kind", String("llvm")}, + {"mtriple", String("aarch64-linux-gnu")}, + {"mcpu", String("cortex-a72")}, + {"mattr", Array{"+neon"}}, + {"num-cores", Integer(4)}}}}); + +TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier") + .set_config({{"kind", String("cuda")}, + {"arch", String("sm_72")}, + {"shared_memory_per_block", Integer(49152)}, + {"registers_per_block", Integer(65536)}, + {"max_threads_per_block", Integer(1024)}, + {"thread_warp_size", Integer(32)}, + {"host", Map{{"kind", String("llvm")}, + {"mtriple", String("aarch64-linux-gnu")}, + {"mcpu", String("carmel")}, + {"num-cores", Integer(4)}}}}); + #define TVM_REGISTER_CUDA_TAG(Name, Arch, SharedMem, RegPerBlock) \ TVM_REGISTER_TARGET_TAG(Name).set_config({ \ {"kind", String("cuda")}, \ @@ -318,7 +342,6 @@ TVM_REGISTER_CUDA_TAG("nvidia/geforce-gt-415m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/geforce-gtx-480m", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/geforce-710m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/geforce-410m", "sm_21", 49152, 32768); -TVM_REGISTER_CUDA_TAG("nvidia/jetson-agx-xavier", "sm_72", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/jetson-nano", "sm_53", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/jetson-tx2", "sm_62", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/jetson-tx1", "sm_53", 49152, 32768); diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index e4bf48b2a51e..c562c78bd187 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -254,6 +254,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("mabi") .add_attr_option("system-lib") .add_attr_option("runtime") + .add_attr_option("num-cores") .add_attr_option("link-params", Bool(false)) .add_attr_option("unpacked-api") .add_attr_option("interface-api") diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index dc1ed1c193e8..d01788e92c4c 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -92,7 +92,7 @@ class GPUCodeVerifier : public StmtExprVisitor { const auto* extent = op->value.as(); ICHECK(extent); - std::string name = var.get()->name_hint; + std::string name = op->node.as()->thread_tag; // record the number of threads in a block if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z" || name == "vthread") { @@ -151,6 +151,7 @@ class GPUCodeVerifier : public StmtExprVisitor { errors_.push_back(s.str()); } }; + err("threads per block", thread_per_block_, max_threads_per_block_); err("local memory per block", local_memory_per_block_, max_local_memory_per_block_); err("shared memory per block", shared_memory_per_block_, max_shared_memory_per_block_); diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 101d80a52ea1..a9234d6fe85b 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -21,9 +21,13 @@ * \file src/tir/ir/function.cc * \brief The function data structure. */ +#include #include #include #include +#include + +#include "../../support/nd_int_set.h" namespace tvm { namespace tir { @@ -64,6 +68,195 @@ FuncType PrimFuncNode::func_type_annotation() const { TVM_REGISTER_NODE_TYPE(PrimFuncNode); +Array IndexMapNode::Apply(const Array& inputs) const { + CHECK_EQ(inputs.size(), this->src_iters.size()); + arith::Analyzer analyzer; + int n = inputs.size(); + for (int i = 0; i < n; ++i) { + analyzer.Bind(this->src_iters[i], inputs[i]); + } + Array results; + results.reserve(this->tgt_iters.size()); + for (PrimExpr result : this->tgt_iters) { + results.push_back(analyzer.Simplify(std::move(result))); + } + return results; +} + +Array IndexMapNode::MapShape(const Array& shape) const { + using namespace support; + Array indices; + std::unordered_map dom_map; + for (const PrimExpr dim : shape) { + Var var; + indices.push_back(var); + dom_map.emplace(var.get(), arith::IntSet::FromMinExtent(0, dim)); + } + Array mapped_indices = Apply(indices); + NDIntSet nd_int_set = NDIntSetFromPoint(mapped_indices); + nd_int_set = NDIntSetEval(nd_int_set, dom_map); + Array new_shape; + for (const auto& int_set : nd_int_set) { + ICHECK(is_zero(int_set.min())); + new_shape.push_back(int_set.max() + 1); + } + auto fmul = [](PrimExpr a, PrimExpr b, Span span) { return a * b; }; + PrimExpr old_size = foldl(fmul, Integer(1), shape); + PrimExpr new_size = foldl(fmul, Integer(1), new_shape); + + arith::Analyzer analyzer; + CHECK(analyzer.CanProveEqual(old_size, new_size)) + << "ValueError: The size of the new shape after IndexMap " << new_shape + << " doesn't match the size of the original shape " << shape; + return new_shape; +} + +String IndexMapNode::ToPythonString() const { + std::unordered_set used_names; + Map var_remap; + for (const Var& src_iter : src_iters) { + if (used_names.count(src_iter->name_hint)) { + std::string new_name = src_iter->name_hint + std::to_string(used_names.size()); + used_names.insert(new_name); + var_remap.Set(src_iter, Var(new_name)); + } else { + used_names.insert(src_iter->name_hint); + } + } + std::ostringstream oss; + oss << "lambda "; + for (size_t i = 0; i < src_iters.size(); ++i) { + if (i != 0) { + oss << ", "; + } + auto it = var_remap.find(src_iters[i]); + if (it != var_remap.end()) { + oss << (*it).second; + } else { + oss << src_iters[i]; + } + } + oss << ": ("; + for (size_t i = 0; i < tgt_iters.size(); ++i) { + if (i != 0) { + oss << ", "; + } + oss << Substitute(tgt_iters[i], var_remap); + } + if (tgt_iters.size() == 1) { + oss << ","; + } + oss << ")"; + return String(oss.str()); +} + +IndexMap::IndexMap(Array src_iters, Array tgt_iters) { + ObjectPtr n = make_object(); + n->src_iters = std::move(src_iters); + n->tgt_iters = std::move(tgt_iters); + data_ = std::move(n); +} + +IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc(Array)> func) { + Array src_iters; + src_iters.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + src_iters.push_back(Var("i" + std::to_string(i), DataType::Int(32))); + } + return IndexMap(src_iters, func(src_iters)); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + const auto* n = node.as(); + ICHECK(n); + p->stream << "IndexMap: ("; + for (int i = 0, total = n->src_iters.size(); i < total; ++i) { + if (i != 0) { + p->stream << ", "; + } + p->stream << n->src_iters[i]; + } + p->stream << ") => "; + p->stream << "("; + for (int i = 0, total = n->tgt_iters.size(); i < total; ++i) { + if (i != 0) { + p->stream << ", "; + } + p->stream << n->tgt_iters[i]; + } + p->stream << ")"; + }); + +TVM_REGISTER_NODE_TYPE(IndexMapNode); +TVM_REGISTER_GLOBAL("tir.IndexMap") + .set_body_typed([](Array src_iters, Array tgt_iters) { + return IndexMap(src_iters, tgt_iters); + }); +TVM_REGISTER_GLOBAL("tir.IndexMapFromFunc").set_body_typed(IndexMap::FromFunc); +TVM_REGISTER_GLOBAL("tir.IndexMapApply").set_body_method(&IndexMapNode::Apply); + +TensorIntrin::TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func) { + // check the number of func var is equal + CHECK_EQ(desc_func->params.size(), intrin_func->params.size()); + CHECK_EQ(desc_func->buffer_map.size(), intrin_func->buffer_map.size()); + + // check both functions' bodies are directly block + const auto* desc_realize = + Downcast(desc_func->body)->block->body.as(); + const auto* intrin_realize = + Downcast(intrin_func->body)->block->body.as(); + CHECK(desc_realize != nullptr) << "description function's body expect a directly block"; + CHECK(intrin_realize != nullptr) << "intrinsic function's body expect a directly block"; + + const Block& desc_block = desc_realize->block; + const Block& intrin_block = intrin_realize->block; + + // check block var number and iter type + CHECK_EQ(desc_block->iter_vars.size(), intrin_block->iter_vars.size()) + << "Two blocks should have the same number of block vars"; + for (size_t i = 0; i < desc_block->iter_vars.size(); i++) { + const IterVar& desc_var = desc_block->iter_vars[i]; + const IterVar& intrin_var = intrin_block->iter_vars[i]; + CHECK(desc_var->iter_type == intrin_var->iter_type) + << "Block iter_type mismatch between " << desc_var->iter_type << " and " + << intrin_var->iter_type; + } + + auto n = make_object(); + n->description = std::move(desc_func); + n->implementation = std::move(intrin_func); + data_ = std::move(n); +} + +class TensorIntrinManager { + public: + Map reg; + + static TensorIntrinManager* Global() { + static TensorIntrinManager* inst = new TensorIntrinManager(); + return inst; + } +}; + +TensorIntrin TensorIntrin::Register(String name, PrimFunc desc_func, PrimFunc intrin_func) { + TensorIntrinManager* manager = TensorIntrinManager::Global(); + ICHECK_EQ(manager->reg.count(name), 0) + << "ValueError: TensorIntrin '" << name << "' has already been registered"; + TensorIntrin intrin(desc_func, intrin_func); + manager->reg.Set(name, intrin); + return intrin; +} + +TensorIntrin TensorIntrin::Get(String name) { + const TensorIntrinManager* manager = TensorIntrinManager::Global(); + ICHECK_EQ(manager->reg.count(name), 1) + << "ValueError: TensorIntrin '" << name << "' is not registered"; + return manager->reg.at(name); +} + +TVM_REGISTER_NODE_TYPE(TensorIntrinNode); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { // TODO(tvm-team) redirect to Text printer once we have a good text format. @@ -85,5 +278,13 @@ TVM_REGISTER_GLOBAL("tir.PrimFunc") return PrimFunc(params, body, ret_type, buffer_map, attrs, span); }); +TVM_REGISTER_GLOBAL("tir.TensorIntrin") + .set_body_typed([](PrimFunc desc_func, PrimFunc intrin_func) { + return TensorIntrin(desc_func, intrin_func); + }); + +TVM_REGISTER_GLOBAL("tir.TensorIntrinRegister").set_body_typed(TensorIntrin::Register); +TVM_REGISTER_GLOBAL("tir.TensorIntrinGet").set_body_typed(TensorIntrin::Get); + } // namespace tir } // namespace tvm diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index d60ec72a7589..d3342fafdc06 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -700,10 +700,11 @@ Array Substitute(const Array& region, const Map& vm } void PreOrderVisit(const ObjectRef& stmt_or_expr, - const std::function& fvisit) { + const std::function& fvisit, bool visit_init_block) { class PreOrderVisitor : public StmtExprVisitor { public: - explicit PreOrderVisitor(const std::function& f) : f_(f) {} + explicit PreOrderVisitor(const std::function& f, bool visit_init_block) + : f_(f), visit_init_block_(visit_init_block) {} private: void VisitExpr(const PrimExpr& expr) final { @@ -726,11 +727,35 @@ void PreOrderVisit(const ObjectRef& stmt_or_expr, } } + void VisitStmt_(const BlockNode* op) final { + auto fvisit_buffer_region = [this](const BufferRegion& s) { + for (const auto& range : s->region) { + this->VisitExpr(range->min); + this->VisitExpr(range->extent); + } + }; + VisitArray(op->iter_vars, [this](const IterVar& iter_var) { + this->VisitExpr(iter_var->dom->min); + this->VisitExpr(iter_var->dom->extent); + }); + VisitArray(op->reads, fvisit_buffer_region); + VisitArray(op->writes, fvisit_buffer_region); + VisitArray(op->match_buffers, + [fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) { + fvisit_buffer_region(match_buffer_region->source); + }); + if (visit_init_block_ && op->init.defined()) { + this->VisitStmt(op->init.value()); + } + this->VisitStmt(op->body); + } + const std::function& f_; + bool visit_init_block_; std::unordered_set visited_; }; - PreOrderVisitor visitor(fvisit); + PreOrderVisitor visitor(fvisit, visit_init_block); if (const auto* stmt = stmt_or_expr.as()) { visitor(GetRef(stmt)); } else if (const auto* expr = stmt_or_expr.as()) { diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index ae72d592339f..35c5ed76ccfd 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -20,6 +20,7 @@ #define TVM_TIR_SCHEDULE_ANALYSIS_H_ #include +#include #include #include @@ -69,6 +70,26 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl */ StmtSRef GetSRefTreeRoot(const StmtSRef& sref); +/*! + * \brief The information of a block scope, including the leaf blocks, + * as well as the loop types (spatial, reduction) for each loop in the scope. + */ +struct ScopeBlockLoopInfo { + /*! \brief A list of the leaf blocks, from left to right */ + std::vector realizes; + /*! \brief The loop vars bound to spatial block iters */ + std::unordered_set spatial_vars; + /*! \brief The loop vars bound to non-spatial block iters */ + std::unordered_set non_spatial_vars; +}; + +/*! + * \brief Inspect the scope of the given sref + * \param scope_block The root block of the scope + * \return The information of the scope + */ +ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block); + /******** Scope ********/ /*! * \brief Checks if scope the specified sref is in is a stage-pipeline and return it @@ -174,6 +195,27 @@ bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref); +/*! + * \brief Check if the block is a data parallel block, i.e. all the block vars are data parallel + * \param block_sref The block to be checked + * \return A boolean flag indicating if the block is a data parallel block + */ +bool IsSpatial(const StmtSRef& block_sref); + +/*! + * \brief Extracts the types of the block vars + * \param block_sref The block to be checked + * \return A vector of types of the block vars + */ +std::vector GetBlockVarTypes(const StmtSRef& block_sref); + +/*! + * \brief Checks if a block could be considered as a "write cache" + * \param block_sref The block to be checked + * \return A boolean flag indicating if the block is a write cache + */ +bool IsWriteCache(const StmtSRef& block_sref); + /******** Binding ********/ /*! * \brief Verifies if the block binding in a specific BlockRealize is an affine binding. @@ -195,6 +237,15 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va */ void CheckAffineBinding(const ScheduleState& self, Block block); +/*! + * \brief Check whether a block has a trivial binding, i.e. each block var is bound to a outer loop, + * from outer to inner. + * \param self The schedule state + * \param block_sref The block to be checked + * \return A boolean flag indicating if the block has a trivial binding + */ +bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref); + /*! * \brief Extracts the ranges of loop variables in a path of the sref tree * \param low_inclusive The lowest node in the path @@ -266,6 +317,145 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self */ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref); +/*! + * \brief Get the IterVarType of the specific loop, according to the blocks it's bound to + * \param loop_sref The loop to be checked + * \return The IterVarType of the specific loop + */ +IterVarType GetLoopIterType(const StmtSRef& loop_sref); + +/*! + * \brief Check whether the loop/block has only one child + * \param loop_or_block_sref The loop/block to be checked + * \return Whether the loop/block has only one child + */ +bool HasSingleChild(const StmtSRef& loop_or_block_sref); + +/*! + * \brief Get the lowest common ancestor of an array of blocks or loops on the sref tree + * \param srefs The block srefs or loop srefs whose lowest common ancestor is to be queried + * \return The lowest common ancestor of the input block srefs or loop srefs + * \note The input array is required to have at least one sref + */ +StmtSRef GetSRefLowestCommonAncestor(const Array& srefs); + +/*! + * \brief Collect all the feasible compute-at locations of the input block + * \param self The schedule state + * \param block_sref The block whose compute-at locations are to be collected + * \return All the feasible compute-at locations of the input block, given as an array of loop srefs + * and an array of their indices among the outer loops of the input block + */ +std::pair, std::vector> CollectComputeLocation(const ScheduleState& self, + const StmtSRef& block_sref); + +/******** Tensorization ********/ + +/* \brief Deep comparison to check if two IR graph are equivalent */ +using ExprComparator = ExprFunctor; +using StmtComparator = StmtFunctor; + +class TensorizeComparator : public ExprComparator, public StmtComparator { + public: + explicit TensorizeComparator(bool assert_mode = true) : assert_mode_(assert_mode) {} + + // Map from rhs buffer to lhs buffer + std::unordered_map rhs_buffer_map_; + // Buffer indices mapping + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_indices_; + std::vector extra_block_vars_; + // variable remap if any + std::unordered_map equal_map_; + + bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override; + bool VisitStmt(const Stmt& n, const Stmt& other) override; + + bool VisitStmt_(const ForNode* op, const Stmt& other) override; + bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override; + bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override; + bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) override; + bool VisitStmt_(const BlockNode* op, const Stmt& other) override; + + bool VisitExpr_(const AddNode* op, const PrimExpr& other) override; + bool VisitExpr_(const SubNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MulNode* op, const PrimExpr& other) override; + bool VisitExpr_(const DivNode* op, const PrimExpr& other) override; + bool VisitExpr_(const ModNode* op, const PrimExpr& other) override; + bool VisitExpr_(const EQNode* op, const PrimExpr& other) override; + bool VisitExpr_(const NENode* op, const PrimExpr& other) override; + bool VisitExpr_(const LTNode* op, const PrimExpr& other) override; + bool VisitExpr_(const LENode* op, const PrimExpr& other) override; + bool VisitExpr_(const GTNode* op, const PrimExpr& other) override; + bool VisitExpr_(const GENode* op, const PrimExpr& other) override; + bool VisitExpr_(const AndNode* op, const PrimExpr& other) override; + bool VisitExpr_(const OrNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MinNode* op, const PrimExpr& other) override; + bool VisitExpr_(const MaxNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloorDivNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloorModNode* op, const PrimExpr& other) override; + bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) override; + bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) override; + bool VisitExpr_(const CastNode* op, const PrimExpr& other) override; + bool VisitExpr_(const VarNode* op, const PrimExpr& other) override; + bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override; + + bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs); + virtual bool CompareBuffer(const Buffer& lhs, const Buffer& rhs); + bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs); + bool CompareAnnotation(const std::pair& lhs, + const std::pair& rhs); + bool CompareAnnotationMap(const Map& lhs, const Map& rhs); + template + bool CompareBufferAccess(const T* lhs, const T* rhs); + template + bool CompareArray(const Array& lhs, const Array& rhs, F cmp); + bool CompareRange(const Range& lhs, const Range& rhs); + bool CompareType(const DataType& lhs, const DataType& rhs); + + protected: + bool assert_mode_; + bool is_scope_block = true, is_inner_block = true; +}; + +/*! \brief Necessary information used for tensorization */ +class TensorizeInfoNode : public Object { + public: + /*! \brief Maps block loops to desc loops */ + Map loop_map; + /*! \brief Maps loops in desc to its index, outer to inner */ + Map desc_loop_indexer; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("loop_map", &loop_map); + v->Visit("desc_loop_indexer", &desc_loop_indexer); + } + + static constexpr const char* _type_key = "tir.analysis.TensorizeInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object); +}; + +/*! + * \brief Managed reference to TensorizeInfoNode + * \sa TensorizeInfoNode + */ +class TensorizeInfo : public ObjectRef { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode); +}; + +/*! + * \brief Check if the given block can be tensorized, and in the meantime gather the necessary + * information for tensorization + * \param self The schedule state + * \param block_sref The block to be analyzed + * \param desc_func The target function for tensorization + * \return The necessary information used for tensorization, or NullOpt if the block cannot be + * tensorized + */ +TVM_DLL Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func); + /******** Producer-consumer relation ********/ /*! @@ -328,6 +518,16 @@ struct ProducerConsumerSplit { */ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write); +/*! + * \brief Find the defining site of the buffer in the given block and its ancestors + * \param block_sref The block sref + * \param buffer The buffer + * \return The defining site of the buffer and whether the buffer is allocated (otherwise the + * buffer is from match_buffer). + */ +std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, + const Buffer& buffer); + /******** Reduction Block Related ********/ /*! @@ -395,6 +595,80 @@ bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, /******** Misc ********/ +/*! + * \brief Given the read/write region, extract the pattern of their index correspondence + * namely, the mapping from read index to the write index. + * \param read_region The read region + * \param write_region The write region + * \return A tuple of booleans, the extracted pattern + * 0) exists: if the pattern is found + * 1) surjective: if the pattern is surjective, i.e. each write index is mapped at least once + * e.g. A[i, j] = B[i, i, j] + * 2) injective: if the pattern is injective, i.e. each write index is mapped at most once. + * e.g. A[i, j] = B[i] + * 3) ordered: if the mapping is ordered + * 4) no_const_read: if there is no constant indexing in the read indices, + * e.g. A[i, j] = B[0, i, j] + * 5) no_shift_read: if there is no constant shift in the read indices, + * e.g. A[i, j] = B[i + 1, j] + */ +std::tuple +AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region); + +/*! + * \brief Checks if the given block has data reuse opportunity and thus multi-level tiling is + * beneficial. + * \param self The schedule state + * \param block_sref The block to be checked + * \return A boolean indicating whether the block has data reuse opportunity + */ +bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref); + +/*! + * \brief Checks if the given block has been applied by multi-level tiling. We check this by examine + * the block's annotation. + * \param block_sref The block to be checked + * \return A boolean indicating whether the block has been multi-level tiled. + */ +bool HasBeenMultiLevelTiled(const StmtSRef& block_sref); + +/*! + * \brief Checks if the rfactor or cross thread reduction is beneficial to the given block. + * \param self The schedule state. + * \param block_sref The block to be checked. + * \param max_parallel_extent The maximum parallel jobs on the target. + * \param max_parallel_extent The maximum cores on the target. + * \return A boolean indicating whether the operation is beneficial. + */ +bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // + const tir::StmtSRef& block_sref, // + int64_t max_parallel_extent, // + int64_t max_parallel_basic); + +/*! + * \brief Checks if the given AST contains the specific operators + * \param stmt The AST to be checked + * \param ops The list of operators to be checked + * \return A boolean indicating whether the AST contains the specific operators + */ +bool HasOp(const Stmt& stmt, const Array& ops); + +/*! + * \brief Checks if the given AST contains if-then-else, including + * 1) IfThenElse statement + * 2) Select expression + * 3) The operator `tir.if_then_else` + * 4) Block predicates + */ +bool HasIfThenElse(const Stmt& stmt); + +/******** Storage Scope ********/ + /*! * \brief Check whether the input storage scope string is valid. Throw an error if not. * \param self The schedule state @@ -420,27 +694,17 @@ bool CanComputeInline(const ScheduleState& self, const StmtSRef& block_sref); bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref); /*! - * \brief Checks if a producer block could be successfully computed at the specific loop. - * \param self The schedule state - * \param block_sref The block to be moved - * \param loop_sref The loop where the block to be moved to - * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 - * \return A boolean indicating whether the block could be successfully compute at the specific loop + * \brief Provided the access pattern to a buffer, suggest one of the possible layout + * transformation to minimize the locality of the access pattern. + * \param buffer The buffer to be transformed + * \param indices The access pattern to the buffer + * \param loops The loops above the buffer + * \param predicate The predicate of the access + * \param analyzer Arithmetic analyzer */ -bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref, - bool preserve_unit_loops); - -/*! - * \brief Checks if a consumer block could be successfully computed at the specific loop. - * \param self The schedule state - * \param block_sref The block to be moved - * \param loop_sref The loop where the block to be moved to - * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 - * \return A boolean indicating whether the block could be successfully reverse compute at the - * specific loop - */ -bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref, - const StmtSRef& loop_sref, bool preserve_unit_loops); +Optional SuggestIndexMap(const Buffer& buffer, const Array& indices, + const Array& loops, const PrimExpr& predicate, + arith::Analyzer* analyzer); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 0a7d57effd0d..a6bb3f3e17b6 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "../utils.h" namespace tvm { @@ -47,6 +49,37 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl /******** Scope ********/ +ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block) { + struct Collector : public StmtVisitor { + void VisitStmt_(const BlockRealizeNode* realize) final { + result.realizes.push_back(GetRef(realize)); + const Array& iter_vars = realize->block->iter_vars; + const Array& iter_values = realize->iter_values; + ICHECK_EQ(iter_vars.size(), iter_values.size()); + int n = realize->iter_values.size(); + for (int i = 0; i < n; ++i) { + const IterVar& iter_var = iter_vars[i]; + const PrimExpr& iter_value = iter_values[i]; + std::unordered_set* vars = nullptr; + if (iter_var->iter_type == IterVarType::kDataPar) { + vars = &result.spatial_vars; + } else { + vars = &result.non_spatial_vars; + } + PostOrderVisit(iter_value, [vars](const ObjectRef& obj) { + if (const VarNode* var = obj.as()) { + vars->insert(var); + } + }); + } + } + + ScopeBlockLoopInfo result; + } visitor; + visitor(scope_block->body); + return std::move(visitor.result); +} + StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, // bool require_stage_pipeline, // bool require_subtree_compact_dataflow) { @@ -408,6 +441,43 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, } } +bool IsSpatial(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + for (const IterVar& iter_var : block->iter_vars) { + if (iter_var->iter_type != IterVarType::kDataPar) { + return false; + } + } + return true; +} + +std::vector GetBlockVarTypes(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + std::vector results; + results.reserve(block->iter_vars.size()); + for (const IterVar& iter_var : block->iter_vars) { + results.push_back(iter_var->iter_type); + } + return results; +} + +bool IsWriteCache(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + if (block->writes.size() != 1) { + return false; + } + const BufferRegion& write_region = block->writes[0]; + for (const BufferRegion& read_region : block->reads) { + bool exists, surjective, injective, ordered, no_const_read, no_shift_read; + std::tie(exists, surjective, injective, ordered, no_const_read, no_shift_read) = + AnalyzeReadWritePattern(read_region, write_region); + if (!(injective && ordered)) { + return false; + } + } + return true; +} + /******** Binding ********/ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_var_ranges, @@ -457,6 +527,22 @@ void CheckAffineBinding(const ScheduleState& self, Block block) { } } +bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Array loops = GetLoops(block_sref); + Array binds = GetBlockRealize(self, block_sref)->iter_values; + if (loops.size() != binds.size()) { + return false; + } + for (int i = 0, n = loops.size(); i < n; ++i) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loops[i]); + if (binds[i].get() != loop->loop_var.get()) { + return false; + } + } + return true; +} + Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, const Optional& high_exclusive, const runtime::StorageScope& extra_relax_scope) { @@ -646,6 +732,335 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr } } +IterVarType GetLoopIterType(const StmtSRef& loop_sref) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + const Var& loop_var = loop->loop_var; + int n_spatial = 0; + int n_reduce = 0; + int n_other = 0; + auto f_visit = [&loop_var, &n_spatial, &n_reduce, &n_other](const ObjectRef& obj) -> bool { + if (const auto* realize = obj.as()) { + const BlockNode* block = realize->block.get(); + // Number of block vars and their bindings + ICHECK_EQ(realize->iter_values.size(), block->iter_vars.size()); + int n = realize->iter_values.size(); + for (int i = 0; i < n; ++i) { + const IterVar& iter_var = block->iter_vars[i]; + const PrimExpr& binding = realize->iter_values[i]; + // Categorize the current block var + int* ref = nullptr; + if (iter_var->iter_type == IterVarType::kDataPar) { + ref = &n_spatial; + } else if (iter_var->iter_type == IterVarType::kCommReduce) { + ref = &n_reduce; + } else { + ref = &n_other; + } + // Visit the binding to see if `loop_var` appears + PostOrderVisit(binding, [&ref, &loop_var](const ObjectRef& obj) -> void { + if (obj.same_as(loop_var)) { + (*ref) += 1; + } + }); + } + return false; + } + return true; + }; + PreOrderVisit(loop->body, f_visit); + if (n_other) { + return IterVarType::kOpaque; + } else if (n_spatial && n_reduce) { + return IterVarType::kOpaque; + } else if (n_reduce) { + return IterVarType::kCommReduce; + } else { + return IterVarType::kDataPar; + } +} + +bool HasSingleChild(const StmtSRef& loop_or_block_sref) { + const StmtNode* body = nullptr; + if (const auto* loop = loop_or_block_sref->StmtAs()) { + body = loop->body.get(); + } else if (const auto* block = loop_or_block_sref->StmtAs()) { + body = block->body.get(); + } else { + LOG(FATAL) << "TypeError: Unable to recognize the type of `loop_or_block_sref`: " + << loop_or_block_sref->stmt->GetTypeKey(); + } + if (body->IsInstance()) { + const auto* seq_stmt = static_cast(body); + return seq_stmt->seq.size() == 1; + } + return true; +} + +StmtSRef GetSRefLowestCommonAncestor(const Array& srefs) { + CHECK(!srefs.empty()) << "ValueError: The input array is required to have at least one sref"; + + std::unordered_map sref_visited_cnt; + for (const StmtSRef& sref : srefs) { + const StmtSRefNode* p = sref.get(); + while (p != nullptr) { + ++sref_visited_cnt[p]; + p = p->parent; + } + } + + int n_sref = static_cast(srefs.size()); + const StmtSRefNode* p = srefs[0].get(); + while (p != nullptr && sref_visited_cnt[p] != n_sref) { + p = p->parent; + } + ICHECK(p != nullptr); + return GetRef(p); +} + +std::pair, std::vector> CollectComputeLocation(const ScheduleState& self, + const StmtSRef& block_sref) { + Array location_srefs; + std::vector location_indices; + + // Step 1. Add the "compute-root" candidate. Add the "compute-inline" candidate if the block can + // be inlined. + if (CanComputeInline(self, block_sref)) { + location_srefs.push_back(StmtSRef::InlineMark()); + location_indices.push_back(-2); + } + location_srefs.push_back(StmtSRef::RootMark()); + location_indices.push_back(-1); + + // Step 2. If the block has no consumer, there is no more candidate. + Array consumers = GetConsumers(self, block_sref); + if (consumers.empty()) { + return std::make_pair(location_srefs, location_indices); + } + // Step 3. Get the deepest loop that the input block can be computed at (namely "boundary"). If + // such a loop cannot be found, there is no more candidate and we just return. + StmtSRef loop_boundary = consumers.size() > 1 ? GetSRefLowestCommonAncestor(consumers) + : GetRef(consumers[0]->parent); + if (loop_boundary->StmtAs() != nullptr) { + return std::make_pair(location_srefs, location_indices); + } + + // Step 4. Collect the loops outside the first consumer and locate the boundary loop. The position + // of the boundary loop reveals the number of possible additional candidates. + Array loop_srefs = GetLoops(consumers[0]); + int lca_pos = std::find(loop_srefs.begin(), loop_srefs.end(), loop_boundary) - loop_srefs.begin(); + ICHECK_LT(lca_pos, static_cast(loop_srefs.size())); + int n_candidate = lca_pos + 1; + + // Step 5. Find the position of the deepest data-parallel loop among the candidate loops. This + // position is used for removing the unwanted candidates from the perspective of performance. + std::vector loop_iter_types; + loop_iter_types.reserve(n_candidate); + int i_last_datapar = -1; + for (int i = 0; i < n_candidate; ++i) { + IterVarType iter_type = GetLoopIterType(loop_srefs[i]); + loop_iter_types.push_back(iter_type); + if (iter_type == IterVarType::kDataPar) { + i_last_datapar = i; + } + } + // Step 6. Check and add the candidates in turn according to the following rules: + // - skip the unit loops (loops with extent 1); + // - do not consider the data-parallel loops after a not-data-parallel loop; + // - do not consider the trailing not-data-parallel loops. + location_srefs.reserve(n_candidate + 2); + location_indices.reserve(n_candidate + 2); + bool visited_reduce = false; + for (int i = 0; i < n_candidate; ++i) { + const int64_t* loop_extent = GetLoopIntExtent(loop_srefs[i]); + if (loop_extent != nullptr && *loop_extent == 1) { + continue; + } + + if (loop_iter_types[i] == IterVarType::kDataPar) { + if (visited_reduce) { + break; + } + } else { + visited_reduce = true; + if (i > i_last_datapar) { + break; + } + } + + location_srefs.push_back(loop_srefs[i]); + location_indices.push_back(i); + } + + return std::make_pair(location_srefs, location_indices); +} + +/******** Tensorization ********/ + +class AutoTensorizeComparator : public tir::TensorizeComparator { + public: + AutoTensorizeComparator() : tir::TensorizeComparator(false) {} + + bool VisitStmt(const tir::Stmt& n, const tir::Stmt& rhs) override { + if (n.same_as(rhs)) return true; + tir::Stmt lhs = n; + if (lhs->type_index() != rhs->type_index()) { + return false; + } + bool equal = tir::StmtComparator::VisitStmt(lhs, rhs); + ICHECK(equal || !assert_mode_) << "Statements are not matching between:\n" + << n << "\nand\n" + << rhs; + return equal; + } + + bool CompareBuffer(const tir::Buffer& lhs, const tir::Buffer& rhs) override { + if (lhs.same_as(rhs)) return true; + // Remap both buffer itself and buffer data + // Skip buffer shape + bool equal = DefEqual(lhs, rhs) && DefEqual(lhs->data, rhs->data) && + lhs->buffer_type == rhs->buffer_type && CompareType(lhs->dtype, rhs->dtype); + if (equal) rhs_buffer_map_[rhs] = lhs; + return equal; + } +}; + +Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func) { + // Try to do tiling automatically if possible + // Now the heuristic is that if block's block var binding is constant + loop var, + // in other words, with tir.block(..., vi=Ci+i, vj=Cj+j, vk=Ck+k), then we split and reorder + // i, j, k according to the loops outside desc_block + // Collect the loops outside block + arith::Analyzer analyzer; + const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); + // Step 1. Analyze desc_func, extract its block, loops and loop vars + const tir::BlockRealizeNode* desc_block = nullptr; + std::vector desc_loops; + std::unordered_set desc_loop_vars; + { + auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars, + &analyzer](const ObjectRef& obj) -> bool { + // Extract the block + if (const auto* block = obj.as()) { + desc_block = block; + return false; + } + // Extract the loops + if (const auto* loop = obj.as()) { + desc_loops.push_back(loop); + desc_loop_vars.insert(loop->loop_var.get()); + if (!analyzer.CanProve(loop->min == 0)) { + return false; + } + } + return true; + }; + const auto* desc_body = + Downcast(desc_func->body)->block->body.as(); + ICHECK(desc_body); + tir::PostOrderVisit(desc_body->block->body, f_visit); + std::reverse(desc_loops.begin(), desc_loops.end()); + ICHECK(desc_block); + } + // Step 2. Check if `desc_block` matches `block` + // Ignore the scope of buffers when comparing, since we can do cache_read/write + if (!AutoTensorizeComparator().VisitStmt(block, GetRef(desc_block))) { + return NullOpt; + } + // Step 3. Extract the loops on top of the block. It is a mirror step of Step 1 + std::vector block_loops; + std::unordered_set block_loop_vars; + { + for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) { + const auto* loop = loop_sref->StmtAs(); + if (loop == nullptr || loop->body->IsInstance()) { + break; + } + block_loops.push_back(loop); + block_loop_vars.insert(loop->loop_var.get()); + if (!analyzer.CanProve(loop->min == 0)) { + return NullOpt; + } + } + std::reverse(block_loops.begin(), block_loops.end()); + } + // Step 4. Map from block loops to desc block loops + ObjectPtr ret = make_object(); + int n_block_vars = block->iter_values.size(); + int n_desc_vars = desc_block->iter_values.size(); + int offset = n_block_vars - n_desc_vars; + if (offset < 0) { + return NullOpt; + } + // We align the block and desc block's bindings from the right side + // block (v0=..., v1=..., v2=...) + // ^ i_block + // desc_block( v1=..., v2=...) + // ^ i_desc + for (int i_desc = 0, i_block = offset; i_desc < n_desc_vars; ++i_desc, ++i_block) { + // For each block var binding, we find + const PrimExpr& block_bind = block->iter_values[i_block]; + const PrimExpr& desc_bind = desc_block->iter_values[i_desc]; + // Step 4.1. Find the corresponding loop of the i-th block var of block + const tir::ForNode* block_loop = nullptr; + for (int i = 0, n = block_loops.size(); i < n; ++i) { + // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars + PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var); + if (!UsesVar(r, + [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) { + block_loop = block_loops[i]; + break; + } + } + if (block_loop == nullptr) { + return NullOpt; + } + // Step 4.2. Find the corresponding loop of the i-th block var of desc + const tir::ForNode* desc_loop = nullptr; + for (int i = 0, n = desc_loops.size(); i < n; ++i) { + // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars + PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var); + if (!UsesVar(r, + [&desc_loop_vars](const VarNode* var) { return desc_loop_vars.count(var); })) { + desc_loop = desc_loops[i]; + break; + } + } + if (block_loop == nullptr) { + return NullOpt; + } + // Step 4.3. Check divisibility of loop extents + PrimExpr block_extent = analyzer.Simplify(block_loop->extent); + PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent); + if (const auto* int_block_extent = block_extent.as()) { + if (const auto* int_desc_extent = desc_extent.as()) { + if (int_block_extent->value % int_desc_extent->value != 0) { + return NullOpt; + } + } else { + return NullOpt; + } + } else { + return NullOpt; + } + // Step 4.4. Maps the result of Step 4.1 to Step 4.2 + const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; + auto it = ret->loop_map.find(block_loop_sref); + if (it == ret->loop_map.end()) { + ret->loop_map.Set(block_loop_sref, GetRef(desc_loop)); + } else if ((*it).second.get() != desc_loop) { + return NullOpt; + } + } + for (int i = 0, n = desc_loops.size(); i < n; ++i) { + ret->desc_loop_indexer.Set(GetRef(desc_loops[i]), Integer(i)); + } + return TensorizeInfo(ret); +} + +TVM_REGISTER_NODE_TYPE(TensorizeInfoNode); + /******** Producer-consumer relation ********/ Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope) { @@ -819,6 +1234,37 @@ Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, return access_region[n]->buffer; } +std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, + const Buffer& buffer) { + // Climb up along the sref tree, and find the block where `buffer` is in alloc_buffers or + // match_buffers. + const StmtSRefNode* defining_site_sref = block_sref.get(); + while (defining_site_sref != nullptr) { + const auto* block = defining_site_sref->StmtAs(); + // If this sref is not a block sref, skip it. + if (block == nullptr) { + defining_site_sref = defining_site_sref->parent; + continue; + } + // Try to find the buffer in `alloc_buffers` + for (const Buffer& alloc_buffer : block->alloc_buffers) { + if (buffer.same_as(alloc_buffer)) { + return {GetRef(defining_site_sref), true}; + } + } + // Try to find the buffer in `match_buffers` + for (const MatchBufferRegion match_buffer : block->match_buffers) { + if (buffer.same_as(match_buffer)) { + return {GetRef(defining_site_sref), false}; + } + } + defining_site_sref = defining_site_sref->parent; + } + // If we cannot find the defining site block, it means that the buffer must be in the function's + // buffer_map, which isn't an intermediate buffer. + return {NullOpt, false}; +} + /******** Pattern Matcher ********/ /*! @@ -1345,6 +1791,311 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref) { return GetRef(p); } +/******** Misc ********/ + +std::tuple +AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region) { + static constexpr const std::tuple kNotExist = { + false, false, false, false, false, false}; + // Step 1. Extract the write indices + int w_dim = write_region->buffer->shape.size(); + std::unordered_map var2idx; + var2idx.reserve(w_dim); + for (int i = 0; i < w_dim; ++i) { + const Range& dom = write_region->region[i]; + if (as_const_int(dom->extent) == nullptr) { + return kNotExist; + } + if (const auto* v = dom->min.as()) { + var2idx.emplace(v, i); + } else { + return kNotExist; + } + } + // Step 2. Map each read index to a write index + bool no_const_read = true; + bool no_shift_read = true; + int r_dim = read_region->buffer->shape.size(); + std::vector mapped(r_dim, -1); + for (int i = 0; i < r_dim; ++i) { + const Range& dom = read_region->region[i]; + if (as_const_int(dom->extent) == nullptr) { + return kNotExist; + } + // Case 1. Read index is a constant + if (as_const_int(dom->min) != nullptr) { + no_const_read = false; + continue; + } + // Case 2. Read index cannot be recognized as `var +/- const` + // where `var` is a write index and `const` is an optional constant shift + Optional opt_const = NullOpt; + const VarNode* var = + static_cast(AnalyzeVarWithShift(dom->min, &opt_const).get()); + if (var == nullptr || !var2idx.count(var)) { + return kNotExist; + } + // Case 3. Read index is `var +/- const` + mapped[i] = var2idx.at(var); + if (opt_const.defined()) { + no_shift_read = false; + } + } + // Step 3. Check if the mapping is ordered, and count how many times each var is mapped + std::vector mapped_counter(w_dim, 0); + bool ordered = true; + int last_mapped = -1; + for (int i : mapped) { + if (i != -1) { + ++mapped_counter[i]; + if (last_mapped != -1 && last_mapped > i) { + ordered = false; + } + last_mapped = i; + } + } + // Step 4. Check if the mapping is surjective or injective + // Surjective: each write index is mapped at least once + // Injective: each write index is mapped at most once + bool surjective = true; + bool injective = true; + for (int cnt : mapped_counter) { + if (cnt == 0) { + surjective = false; + } else if (cnt >= 2) { + injective = false; + } + } + return {/*exist=*/true, surjective, injective, ordered, no_const_read, no_shift_read}; +} + +bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + if (block->writes.size() != 1 || block->reads.empty() || IsSpatial(block_sref) || + !IsTrivialBinding(self, block_sref)) { + return false; + } + const BufferNode* write_buffer = block->writes[0]->buffer.get(); + // Step 1. Sort out spatial block variables + std::vector spatial_block_vars; + spatial_block_vars.reserve(block->iter_vars.size()); + for (const IterVar& block_var : block->iter_vars) { + if (block_var->iter_type == IterVarType::kDataPar) { + spatial_block_vars.push_back(block_var->var.get()); + } + } + // Step 2. Enumerate each read region, check the number of block vars that are not used + // to index the read region + int total_unused_block_vars = 0; + std::unordered_set read_buffers; + read_buffers.reserve(block->reads.size()); + for (const BufferRegion& buffer_region : block->reads) { + const BufferNode* buffer = buffer_region->buffer.get(); + const Array& regions = buffer_region->region; + // Step 2.1. Duplication of read buffers are not allowed + if (read_buffers.insert(buffer).second == false) { + return false; + } + // Step 2.2. Skip the reduction buffer + if (buffer == write_buffer) { + continue; + } + // Step 2.3. Collect the block vars that are used to index the read region + std::unordered_set vars; + for (const Range& range : regions) { + if (as_const_int(range->extent) == nullptr) { + return false; + } + for (const Var& var : UndefinedVars(range->min)) { + vars.insert(var.get()); + } + } + // Step 2.4. Check if the block vars are not used to index the read region + int n_unused_block_vars = 0; + for (const VarNode* block_var : spatial_block_vars) { + if (vars.count(block_var) == 0) { + ++n_unused_block_vars; + } + } + total_unused_block_vars += n_unused_block_vars; + } + return total_unused_block_vars >= 1; +} + +bool HasBeenMultiLevelTiled(const StmtSRef& block_sref) { + return tir::GetAnn(block_sref, tir::attr::meta_schedule_tiling_structure).defined(); +} + +std::pair GetCumulativeSpaceAndReductionLength(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref) { + Array loops = tir::GetLoops(block_sref); + int64_t cum_space_len = 1, cum_reduce_len = 1; + /* + * Return (-1, -1) if + * 1. there is some loop with type other than kDataPar and kCommReduce; + * 2. there is some loop which is dynamic. + */ + for (const tir::StmtSRef& loop_sref : loops) { + tir::IterVarType type = GetLoopIterType(loop_sref); + if (type == tir::kDataPar) { + const int64_t* extent = GetLoopIntExtent(loop_sref); + if (*extent != -1) { + cum_space_len *= *extent; + } else { + return std::make_pair(-1, -1); + } + } else if (type == tir::kCommReduce) { + const int64_t* extent = GetLoopIntExtent(loop_sref); + if (*extent != -1) { + cum_reduce_len *= *extent; + } else { + return std::make_pair(-1, -1); + } + } else { + return std::make_pair(-1, -1); + } + } + return std::make_pair(cum_space_len, cum_reduce_len); +} + +bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // + const tir::StmtSRef& block_sref, // + int64_t max_parallel_extent, // + int64_t max_parallel_basic) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Array loops = tir::GetLoops(block_sref); + + // Cond 1. The block has no annotations + if (!block->annotations.empty()) { + return false; + } + + // Cond 2. The block has only one write buffer + if (block->writes.size() != 1) { + return false; + } + + // Cond 3. The block satisfies all the following properties + // - it is a reduction block; + // - it has trivial bindings; + // - it has not been tiled by multi-level tiling. + const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, // + /*require_stage_pipeline=*/false, // + /*require_subtree_compact_dataflow=*/false); + if (!IsReductionBlock(self, block_sref, scope_sref) // + || !IsTrivialBinding(self, block_sref) // + || HasBeenMultiLevelTiled(block_sref)) { + return false; + } + + // Cond 4. Every the loop axis must be either spatial axis or reduction axis. + for (const tir::StmtSRef& loop_sref : loops) { + const tir::IterVarType& type = GetLoopIterType(loop_sref); + if (type != tir::kDataPar && type != tir::kCommReduce) { + return false; + } + } + + // Cond 5. Whether there is at least one reduction loop. + // Cond 6. The loops are continuous, and the body of the innermost loop is exactly the block. + bool has_reduction_loop = false; + for (size_t i = 0; i < loops.size(); ++i) { + // Cond 5. + if (GetLoopIterType(loops[i]) == tir::kCommReduce) { + has_reduction_loop = true; + } + + // Cond 6. + const ForNode* loop_i = TVM_SREF_TO_FOR(loop_i, loops[i]); + if (i < loops.size() - 1) { + const ForNode* loop_i1 = TVM_SREF_TO_FOR(loop_i1, loops[i + 1]); + if (loop_i->body.get() != loop_i1) { + return false; + } + } else { + const auto* block_realize = loop_i->body.as(); + if (!block_realize || block_realize->block.get() != block) { + return false; + } + } + } + if (!has_reduction_loop) { + return false; + } + + // Cond 7. Can successfully calculating the cumulative loop length. + int64_t cum_space_len, cum_reduce_len; + std::tie(cum_space_len, cum_reduce_len) = GetCumulativeSpaceAndReductionLength(self, block_sref); + if (cum_space_len == -1 || cum_reduce_len == -1) { + return false; + } + + // Cond 8. + if (NeedsMultiLevelTiling(self, block_sref)) { + // Do not use rfactor/cross-thread-reduction if we have enough parallelism on spatial loops. + return !(cum_space_len >= cum_reduce_len || cum_space_len > max_parallel_extent); + } else if (cum_reduce_len > 1) { + // Always try rfactor/cross-thread-reduction for other reduction blocks. + return cum_reduce_len > max_parallel_basic; + } else { + return false; + } +} + +bool HasOp(const Stmt& stmt, const Array& ops) { + std::unordered_set op_set; + op_set.reserve(ops.size()); + for (const Op& op : ops) { + op_set.insert(op.operator->()); + } + bool found = false; + PreOrderVisit(stmt, [&found, &op_set](const ObjectRef& obj) -> bool { + if (found) { + return false; + } + if (const auto* call = obj.as()) { + if (op_set.count(call->op.operator->())) { + found = true; + } + } + return !found; + }); + return found; +} + +bool HasIfThenElse(const Stmt& stmt) { + bool has_branch = false; + auto f_visit = [&has_branch](const ObjectRef& obj) -> bool { + if (has_branch) { + // stop visiting + return false; + } + if (const auto* realize = obj.as()) { + // Case 1: BlockRealize + if (!is_one(realize->predicate)) { + has_branch = true; + } + } else if (obj->IsInstance() || obj->IsInstance()) { + // Case 2: IfThenElse / Select + has_branch = true; + } else if (const auto* call = obj.as()) { + // Case 3: Call + static const Op& op_if_then_else = Op::Get("tir.if_then_else"); + if (call->op.same_as(op_if_then_else)) { + has_branch = true; + } + } + return !has_branch; + }; + PreOrderVisit(stmt, f_visit); + return has_branch; +} + /******** Storage Scope ********/ void CheckStorageScope(const ScheduleState& self, String storage_scope) { diff --git a/src/tir/schedule/analysis/layout.cc b/src/tir/schedule/analysis/layout.cc new file mode 100644 index 000000000000..144b3a55a467 --- /dev/null +++ b/src/tir/schedule/analysis/layout.cc @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Calculate the strides of the buffer + * \param buffer The buffer + * \return The strides + */ +Array GetStrides(const Buffer& buffer) { + if (!buffer->strides.empty()) { + ICHECK_EQ(buffer->strides.size(), buffer->shape.size()); + return buffer->strides; + } + int ndim = buffer->shape.size(); + if (ndim == 0) { + return {}; + } + Array strides(ndim, PrimExpr{nullptr}); + PrimExpr stride = make_const(buffer->DefaultIndexType(), 1); + for (int i = ndim - 1; i >= 0; --i) { + strides.Set(i, stride); + stride = stride * buffer->shape[i]; + } + return strides; +} + +/*! + * \brief Auxiliary class that collects the IterSplitExpr in the indexing pattern + * to help decision making in layout transformation + */ +class SplitExprCollector { + public: + /*! + * \brief The corresponding IterSplitExpr, simplified for our case + * The pattern is `source // lower_factor % extent * scale` + */ + struct SplitExpr { + /*! \brief The source variable */ + Var source; + /*! \brief The lower factor of the split expression */ + int64_t lower_factor; + /*! \brief The extent of the split expression */ + int64_t extent; + }; + + /*! + * \brief Collect the split expressions in the indexing pattern + * \param index The indexing pattern + * \param input_iters The input iterators' domain + * \param predicate The predicate of the affine map + * \param require_bijective Whether the affine map is required to be bijective + * \param analyzer The analyzer + * \return The collected split expressions + */ + static std::vector Collect(const PrimExpr& index, + const Map& input_iters, // + const PrimExpr& predicate, // + bool require_bijective, // + arith::Analyzer* analyzer) { + DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); + Array iter_sum_exprs = arith::DetectIterMap( + {analyzer->Simplify(index)}, input_iters, predicate, require_bijective, analyzer, diag_ctx); + if (iter_sum_exprs.empty()) { + return {}; + } + ICHECK_EQ(iter_sum_exprs.size(), 1); + if (iter_sum_exprs[0]->args.size() == 0) { + return {}; + } + SplitExprCollector collector; + collector.Visit(iter_sum_exprs[0]); + if (collector.failed_) { + return {}; + } + return std::move(collector.exprs_); + } + + private: + void Visit(const arith::IterSplitExpr& expr) { + if (const auto* var = expr->source->source.as()) { + const int64_t* lower_factor = as_const_int(expr->lower_factor); + const int64_t* extent = as_const_int(expr->extent); + if (lower_factor == nullptr || extent == nullptr) { + failed_ = true; + return; + } + exprs_.push_back(SplitExpr{GetRef(var), *lower_factor, *extent}); + } else if (const auto* iter_sum_expr = expr->source->source.as()) { + Visit(GetRef(iter_sum_expr)); + } else { + ICHECK(false) << "Unexpected type: " << expr->source->source->GetTypeKey(); + } + } + + void Visit(const arith::IterSumExpr& expr) { + for (const arith::IterSplitExpr& arg : expr->args) { + Visit(arg); + } + } + + /*! \brief Whether the analysis failed */ + bool failed_ = false; + /*! \brief The collected split expressions */ + std::vector exprs_; +}; + +Optional SuggestIndexMap(const Buffer& buffer, const Array& indices, + const Array& loops, const PrimExpr& predicate, + arith::Analyzer* analyzer) { + int ndim = buffer->shape.size(); + int n_loops = loops.size(); + // Step 1. Collect the domains and indices of loop variables + Map input_iters; + std::unordered_map var2id; + var2id.reserve(n_loops); + for (int i = 0; i < n_loops; ++i) { + const For& loop = loops[i]; + input_iters.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + var2id.emplace(loop->loop_var.get(), i); + } + // Step 2. Calculate a functor that flattens a multi-dimensional index + auto f_flatten_index = [ndim, strides = GetStrides(buffer), dtype = buffer->DefaultIndexType()]( + const Array& indices) -> PrimExpr { + PrimExpr flatten_index = make_const(dtype, 0); + for (int i = 0; i < ndim; ++i) { + flatten_index = flatten_index + strides[i] * indices[i]; + } + return flatten_index; + }; + // Step 3. Detect the IterSplitExpr of the indexing pattern + std::vector split_exprs = SplitExprCollector::Collect( + /*index=*/f_flatten_index(indices), input_iters, predicate, + /*require_bijective=*/false, analyzer); + if (split_exprs.empty()) { + return NullOpt; + } + // Step 4. Sort the order of the split expressions + std::vector order(split_exprs.size(), 0); + std::generate(order.begin(), order.end(), [n = 0]() mutable { return n++; }); + std::sort(order.begin(), order.end(), [&split_exprs, &var2id](int _a, int _b) -> bool { + const SplitExprCollector::SplitExpr& a = split_exprs[_a]; + const SplitExprCollector::SplitExpr& b = split_exprs[_b]; + int a_var_id = var2id.at(a.source.get()); + int b_var_id = var2id.at(b.source.get()); + if (a_var_id != b_var_id) { + return a_var_id < b_var_id; + } + return a.lower_factor > b.lower_factor; + }); + // Step 5. Create the indexing mapping + auto f_alter_layout = [f_flatten_index = std::move(f_flatten_index), // + split_exprs = std::move(split_exprs), // + order = std::move(order), // + shape = buffer->shape, // + analyzer // + ](Array indices) -> Array { + ICHECK_EQ(indices.size(), shape.size()); + for (int i = 0, n = indices.size(); i < n; ++i) { + analyzer->Bind(indices[i], Range::FromMinExtent(0, shape[i])); + } + PrimExpr index = f_flatten_index({indices.begin(), indices.end()}); + int ndim = split_exprs.size(); + // Step 5.1. Split the flattened index according to `split_exprs` + std::vector split; + split.reserve(ndim); + for (int i = ndim - 1; i >= 0; --i) { + index = analyzer->Simplify(index); + int64_t extent = split_exprs[i].extent; + split.push_back(analyzer->Simplify(floormod(index, extent))); + index = floordiv(index, extent); + } + std::reverse(split.begin(), split.end()); + // Step 5.2. Reorder the indexing pattern according to `order` + Array results; + results.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + results.push_back(split[order[i]]); + } + return results; + }; + return IndexMap::FromFunc(ndim, f_alter_layout); +} + +TVM_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap") + .set_body_typed([](Buffer buffer, Array indices, Array loops, + PrimExpr predicate) { + arith::Analyzer analyzer; + return SuggestIndexMap(buffer, indices, loops, predicate, &analyzer); + }); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 65886daad014..afaa998e6c8d 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -18,8 +18,6 @@ */ #include "./concrete_schedule.h" -#include - namespace tvm { namespace tir { @@ -30,7 +28,7 @@ Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRa n->error_render_level_ = error_render_level; n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); - support::LinearCongruentialEngine(&n->rand_state_).Seed(seed); + n->Seed(seed); return Schedule(std::move(n)); } @@ -214,7 +212,7 @@ Schedule ConcreteScheduleNode::Copy() const { void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState seed) { if (seed == -1) { - seed = std::random_device()(); + seed = support::LinearCongruentialEngine::DeviceRandom(); } support::LinearCongruentialEngine(&rand_state_).Seed(seed); } @@ -242,6 +240,15 @@ Array ConcreteScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int throw; } +LoopRV ConcreteScheduleNode::SampleComputeLocation(const BlockRV& block_rv, + Optional decision) { + TVM_TIR_SCHEDULE_BEGIN(); + return CreateRV( + tir::SampleComputeLocation(state_, &this->rand_state_, this->GetSRef(block_rv), &decision)); + TVM_TIR_SCHEDULE_END("sample-compute-location", this->error_render_level_); + throw; +} + /******** Schedule: Get blocks & loops ********/ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) { @@ -504,6 +511,30 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff return CreateRV(result); } +/******** Schedule: Data movement ********/ + +BlockRV ConcreteScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int read_buffer_index, const String& storage_scope) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::ReadAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), read_buffer_index, + storage_scope); + TVM_TIR_SCHEDULE_END("read-at", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + +BlockRV ConcreteScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int write_buffer_index, const String& storage_scope) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::WriteAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), write_buffer_index, + storage_scope); + TVM_TIR_SCHEDULE_END("write-at", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + /******** Schedule: Compute location ********/ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, @@ -597,27 +628,46 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { } /******** Schedule: Blockize & Tensorize ********/ -/******** Schedule: Annotation ********/ +BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::Blockize(state_, this->GetSRef(loop_rv)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_); + return CreateRV(result); +} -ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_val) { - if (ann_val.as()) { - return ann_val; - } - if (const auto* expr = ann_val.as()) { - ICHECK(!ann_val->IsInstance()) - << "TypeError: runtime::String is expected, but gets StringImm"; - return this->Get(GetRef(expr)); - } - LOG(FATAL) - << "TypeError: Only strings, integers, floats, ExprRVs and Arrays are supported for now, but " - << "gets: " << ann_val->GetTypeKey(); - throw; +void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Tensorize(state_, this->GetSRef(loop_rv), intrin); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); +} + +void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin_name) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin_name)); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); } void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, this->CheckAndGetAnnotationValue(ann_val)); + if (const auto* str = ann_val.as()) { + tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, GetRef(str)); + } else if (const auto* expr = ann_val.as()) { + ICHECK(!ann_val->IsInstance()) + << "TypeError: runtime::String is expected, but gets StringImm"; + tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, this->Get(GetRef(expr))); + } else if (ann_val.as()) { + tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, ann_val); + } else { + LOG(FATAL) + << "TypeError: Only strings, integers, floats and ExprRVs are supported for now, but gets: " + << ann_val->GetTypeKey(); + throw; + } this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_); } @@ -632,8 +682,20 @@ void ConcreteScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_k void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key, const ObjectRef& ann_val) { TVM_TIR_SCHEDULE_BEGIN(); - tir::Annotate(state_, this->GetSRef(block_rv), ann_key, - this->CheckAndGetAnnotationValue(ann_val)); + if (const auto* str = ann_val.as()) { + tir::Annotate(state_, this->GetSRef(block_rv), ann_key, GetRef(str)); + } else if (const auto* expr = ann_val.as()) { + ICHECK(!ann_val->IsInstance()) + << "TypeError: runtime::String is expected, but gets StringImm"; + tir::Annotate(state_, this->GetSRef(block_rv), ann_key, this->Get(GetRef(expr))); + } else if (ann_val.as()) { + tir::Annotate(state_, this->GetSRef(block_rv), ann_key, ann_val); + } else { + LOG(FATAL) + << "TypeError: Only strings, integers, floats and ExprRVs are supported for now, but gets: " + << ann_val->GetTypeKey(); + throw; + } this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_); } @@ -645,6 +707,15 @@ void ConcreteScheduleNode::Unannotate(const BlockRV& loop_rv, const String& ann_ TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_); } +/******** Schedule: Layout transformation ********/ +void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index, + bool is_write_index, const IndexMap& index_map) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, is_write_index, index_map); + this->state_->DebugVerify(); + TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_); +} + /******** Schedule: Misc ********/ } // namespace tir diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index d420728a9e3c..cacd8e389dff 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -86,6 +86,8 @@ class ConcreteScheduleNode : public ScheduleNode { Optional decision = NullOpt) override; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) override; + LoopRV SampleComputeLocation(const BlockRV& block_rv, + Optional decision = NullOpt) override; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") override; Array GetLoops(const BlockRV& block_rv) override; @@ -107,6 +109,11 @@ class ConcreteScheduleNode : public ScheduleNode { const String& storage_scope) override; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) override; + /******** Schedule: Data movement ********/ + BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) override; + BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) override; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override; void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, @@ -121,12 +128,18 @@ class ConcreteScheduleNode : public ScheduleNode { int offset) override; void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override; /******** Schedule: Blockize & Tensorize ********/ + BlockRV Blockize(const LoopRV& loop_rv) override; + void Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) override; + void Tensorize(const LoopRV& loop_rv, const String& intrin_name) override; /******** Schedule: Annotation ********/ void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; void Annotate(const BlockRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const BlockRV& loop_rv, const String& ann_key) override; + /******** Schedule: Layout transformation ********/ + void TransformLayout(const BlockRV& block_rv, int buffer_index, bool is_write_index, + const IndexMap& index_map) override; /******** Schedule: Misc ********/ void EnterPostproc() override {} @@ -224,7 +237,7 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { if (it == this->symbol_table_.end()) { LOG(FATAL) << "IndexError: Cannot find corresponding BlockRV: " << block_rv; } - const ObjectRef& obj = (*it).second; + ObjectRef obj = (*it).second; const auto* sref = obj.as(); if (sref == nullptr) { LOG(FATAL) << "ValueError: BlockRV's corresponding type is invalid: " @@ -243,7 +256,7 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const { if (it == this->symbol_table_.end()) { LOG(FATAL) << "IndexError: Cannot find corresponding LoopRV: " << loop_rv; } - const ObjectRef& obj = (*it).second; + ObjectRef obj = (*it).second; if (obj.same_as(inline_mark)) { return inline_mark; } diff --git a/src/tir/schedule/instruction.cc b/src/tir/schedule/instruction.cc index af721767c32f..cedba4b96095 100644 --- a/src/tir/schedule/instruction.cc +++ b/src/tir/schedule/instruction.cc @@ -21,6 +21,11 @@ namespace tvm { namespace tir { +bool InstructionKindNode::IsPostproc() const { + static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc"); + return this == inst_enter_postproc.get(); +} + Instruction::Instruction(InstructionKind kind, Array inputs, Array attrs, Array outputs) { ObjectPtr n = make_object(); diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index b0d4992989fb..71ee09ab6829 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -43,7 +43,7 @@ namespace tir { * * // Convertible to `InstructionKindNode::FInstructionApply` * static Array ApplyToSchedule( - * const tir::Schedule& sch, + * const Schedule& sch, * const Array& inputs, * const Array& attrs, * const Optional& decision); @@ -193,10 +193,12 @@ class PythonAPICall { * \param method_name The name of the schedule API to be called */ explicit PythonAPICall(String method_name) : method_name_(method_name), output_(NullOpt) {} - /*! \brief Add an intger input */ + /*! \brief Add an integer input */ inline void Input(String arg_name, int arg); - /*! \brief Add an intger input */ + /*! \brief Add an integer input */ inline void Input(String arg_name, int64_t arg); + /*! \brief Add a bool input */ + inline void Input(String arg_name, bool arg); /*! \brief Add a double input */ inline void Input(String arg_name, double arg); /*! \brief Add an input random variable */ @@ -462,6 +464,17 @@ void PythonAPICall::Input(String arg_name, int64_t arg) { args_.push_back(std::to_string(arg)); } +void PythonAPICall::Input(String arg_name, bool arg) { + static const char* true_str = "True"; + static const char* false_str = "False"; + arg_names_.emplace_back(std::move(arg_name)); + if (arg) { + args_.push_back(true_str); + } else { + args_.push_back(false_str); + } +} + void PythonAPICall::Input(String arg_name, double arg) { arg_names_.emplace_back(std::move(arg_name)); std::ostringstream os; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 212e53aa500f..a79b5ec3eaba 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -22,6 +22,7 @@ #include #include +#include #include namespace tvm { @@ -30,12 +31,22 @@ namespace tir { /******** Schedule: Sampling ********/ /*! * \brief Sample a random integer from a given range. + * \param rand_state The pointer to schedule's random state * \param min_inclusive The minimum value of the range, inclusive. * \param max_exclusive The maximum value of the range, exclusive. * \return The random integer sampled in the given range. */ TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_state, int32_t min_inclusive, int32_t max_exclusive); +/*! + * \brief Sample k random integers from given range without replacement, i.e, no duplication. + * \param rand_state The pointer to schedule's random state + * \param n The range is defined as 0 to n-1. + * \param k The total number of samples. + * \return The randomly selected samples from the n candidates. + */ +std::vector SampleWithoutReplacement( + support::LinearCongruentialEngine::TRandState* rand_state, int32_t n, int32_t k); /*! * \brief Sample once category from candidates according to the probability weights. * \param rand_state The pointer to schedule's random state @@ -47,6 +58,14 @@ TVM_DLL int32_t SampleInt(support::LinearCongruentialEngine::TRandState* rand_st TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision); +/*! + * \brief Create a sampling function that does multinomial sampling. + * \param rand_state The random state. + * \param weights The weights for multinomial sampling. + * \return The multinomial sampling function. + */ +TVM_DLL std::function MakeMultinomialSampler( + support::LinearCongruentialEngine::TRandState* rand_state, const std::vector& weights); /*! * \brief Sample the factors to perfect tile a specific loop * \param rand_state The random state @@ -81,6 +100,17 @@ TVM_DLL std::vector SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // const tir::StmtSRef& loop_sref, int32_t n_split, int32_t max_innermost_factor, Optional>* decision); +/*! + * \brief Sample a compute-at location of the given block + * \param self The schedule state + * \param rand_state The random state + * \param block_sref The sref of the block whose compute-at location is to be sampled + * \param decision The sampling decision + * \return The sampled loop where the input block is to be computed at + */ +TVM_DLL tir::StmtSRef SampleComputeLocation( + tir::ScheduleState self, support::LinearCongruentialEngine::TRandState* rand_state, + const tir::StmtSRef& block_sref, Optional* decision); /******** Schedule: Get blocks & loops ********/ /*! @@ -224,6 +254,15 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r */ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, const String& storage_scope); + +/******** Schedule: Data movement ********/ + +TVM_DLL StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int read_buffer_index, const String& storage_scope); + +TVM_DLL StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int write_buffer_index, const String& storage_scope); + /******** Schedule: Compute location ********/ /*! * \brief Move a producer block under the specific loop, and regenerate the @@ -350,6 +389,11 @@ TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer const String& storage_scope); /******** Schedule: Blockize & Tensorize ********/ + +TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref); +TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& loop_sref, + const TensorIntrin& intrinsic); + /******** Schedule: Annotation ********/ /*! * \brief Annotate a block/loop with a key value pair @@ -360,6 +404,23 @@ TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer */ TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key, const ObjectRef& ann_val); + +/******** Schedule: Layout transformation ********/ +/*! + * \brief Apply a transformation represented by IndexMap to buffer + * \details The indices and the access region to the target buffer is transformed by the given + * index_map. The index_map is also used to infer the new shape of the buffer. Buffer must be + * one of the parameter of the function, or allocated in some blocks (it cannot be a buffer + * subregion created via match_buffer). + * \param self The state of the schedule + * \param block_sref The block sref that accesses the target buffer. + * \param buffer_index The index of the buffer in block's read or write region. + * \param is_write_index Whether the buffer_index is the index of the block's write region. + * \param index_map The transformation to apply. + */ +TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + bool is_write_index, const IndexMap& index_map); + /*! * \brief Unannotate a block/loop's annotation with key ann_key * \param self The state of the schedule @@ -367,7 +428,6 @@ TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const String& an * \param ann_key The annotation key */ TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key); - /******** Schedule: Misc ********/ } // namespace tir diff --git a/src/tir/schedule/primitive/annotate.cc b/src/tir/schedule/primitive/annotate.cc index 0c79d55fcd86..09b7a47e8ee8 100644 --- a/src/tir/schedule/primitive/annotate.cc +++ b/src/tir/schedule/primitive/annotate.cc @@ -127,8 +127,7 @@ struct AnnotateTraits : public UnpackedInstTraits { return py.Str(); } - template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct UnpackedInstTraits; }; struct UnannotateTraits : public UnpackedInstTraits { @@ -159,8 +158,7 @@ struct UnannotateTraits : public UnpackedInstTraits { return py.Str(); } - template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(AnnotateTraits); diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 181e5a6cfa69..d9b3dda9e9e3 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -64,44 +64,6 @@ class StorageAlignAxisOutOfRangeError : public ScheduleError { int axis_; }; -/*! - * \brief Find the defining site of the buffer in the given block and its ancestors - * \param block_sref The block sref - * \param buffer The buffer - * \return The defining site of the buffer and whether the buffer is allocated (otherwise the - * buffer is from match_buffer). - */ -std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, - const Buffer& buffer) { - // Climb up along the sref tree, and find the block where `buffer` is in alloc_buffers or - // match_buffers. - const StmtSRefNode* defining_site_sref = block_sref.get(); - while (defining_site_sref != nullptr) { - const auto* block = defining_site_sref->StmtAs(); - // If this sref is not a block sref, skip it. - if (block == nullptr) { - defining_site_sref = defining_site_sref->parent; - continue; - } - // Try to find the buffer in `allloc_buffers` - for (const Buffer& alloc_buffer : block->alloc_buffers) { - if (buffer.same_as(alloc_buffer)) { - return {GetRef(defining_site_sref), true}; - } - } - // We do not allow the buffer being defined in `match_buffer`. - for (const MatchBufferRegion match_buffer : block->match_buffers) { - if (buffer.same_as(match_buffer)) { - return {GetRef(defining_site_sref), false}; - } - } - defining_site_sref = defining_site_sref->parent; - } - // If we cannot find the defining site block, it means that the buffer must be in the function's - // buffer_map, which isn't an intermediate buffer. - return {NullOpt, false}; -} - class NonAllocatedBufferError : public ScheduleError { public: explicit NonAllocatedBufferError(IRModule mod, Buffer buffer) : mod_(mod), buffer_(buffer) {} diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc new file mode 100644 index 000000000000..bba95e594b4d --- /dev/null +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -0,0 +1,1014 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../../../arith/pattern_match.h" +#include "../../ir/functor_common.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +bool CheckOneLine(const Stmt& s) { + bool legal = true, meet_block = false; + PostOrderVisit(s, [&legal, &meet_block](const ObjectRef& obj) { + if (obj->IsInstance() && !meet_block) { + legal = false; + } else if (obj->IsInstance()) { + meet_block = true; + } + }); + return legal; +} + +Block GetRootBlock(StmtSRef sref) { + const StmtSRefNode* p_sref = sref.get(); + while (p_sref->parent != nullptr) { + p_sref = p_sref->parent; + } + const BlockNode* root_block = TVM_SREF_TO_BLOCK(root_block, GetRef(p_sref)); + return GetRef(root_block); +} + +void RecalculateCachedFlags(ScheduleStateNode* self) { + ScheduleState new_state(self->mod); + for (const auto& kv : new_state->stmt2ref) { + const StmtNode* stmt = kv.first; + const StmtSRef& new_sref = kv.second; + if (stmt->IsInstance() || !self->stmt2ref.count(stmt)) { + continue; + } + const BlockInfo& new_block_info = new_state->block_info.at(new_sref); + const StmtSRef& old_sref = self->stmt2ref.at(stmt); + BlockInfo& old_block_info = self->block_info.at(old_sref); + old_block_info.affine_binding = new_block_info.affine_binding; + old_block_info.region_cover = new_block_info.region_cover; + old_block_info.scope->stage_pipeline = new_block_info.scope->stage_pipeline; + } +} + +void UpdateScope(ScheduleState self, const StmtSRef& block_sref) { + BlockScope scope(tir::GetChildBlocks(self, block_sref)); + // The caller is responsible for correcting the flags + bool affine_binding = false; + bool region_cover = false; + // TODO(@Wuwei): stage_pipeline + self->block_info[block_sref] = BlockInfo(std::move(scope), affine_binding, region_cover); +} + +bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) { + if (n.same_as(other)) return true; + if (n->type_index() != other->type_index()) return false; + bool equal = StmtComparator::VisitStmt(n, other); + if (!equal && assert_mode_) + LOG(FATAL) << "Stmts are not matching between:\n" << n << "\nand\n" << other; + return equal; +} + +bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) { + const auto* rhs = other.as(); + if (!DefEqual(op->loop_var, rhs->loop_var)) return false; + if (!VisitExpr(op->min, rhs->min)) return false; + if (!VisitExpr(op->extent, rhs->extent)) return false; + if (!VisitStmt(op->body, rhs->body)) return false; + if (op->kind != rhs->kind) return false; + if (op->thread_binding.defined() ^ rhs->thread_binding.defined()) return false; + if (op->thread_binding.defined() && + !VisitExpr(op->thread_binding.value(), rhs->thread_binding.value())) + return false; + return CompareAnnotationMap(op->annotations, rhs->annotations); +} + +bool TensorizeComparator::VisitStmt_(const SeqStmtNode* op, const Stmt& other) { + const auto* rhs = other.as(); + return CompareArray(op->seq, rhs->seq, &TensorizeComparator::VisitStmt); +} + +bool TensorizeComparator::VisitStmt_(const BufferStoreNode* op, const Stmt& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs) && VisitExpr(op->value, rhs->value); +} + +bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& other) { + const auto* rhs = other.as(); + // Skip Compare binding values if the block is scope block (the outermost one). + if (!is_scope_block) { + size_t offset = op->iter_values.size() - rhs->iter_values.size(); + if (rhs->iter_values.size() > op->iter_values.size()) return false; + if (is_inner_block) { + // weak pattern matching for the inner block (the son of the scope block) + // where the pattern is v + iter <=> expr + iter + for (size_t i = 0; i < rhs->iter_values.size(); ++i) { + PrimExpr lhs_expr, rhs_expr; + Optional lhs_iter, rhs_iter; + auto detect = [](const PrimExpr& binding) -> std::pair> { + arith::PVar expr; + arith::PVar iter; + if (iter.Match(binding)) { + return std::make_pair(0, iter.Eval()); + } else if ((expr + iter).Match(binding)) { + return std::make_pair(expr.Eval(), iter.Eval()); + } else if ((iter + expr).Match(binding)) { + return std::make_pair(expr.Eval(), iter.Eval()); + } else { + return std::make_pair(expr.Eval(), NullOpt); + } + }; + std::tie(lhs_expr, lhs_iter) = detect(op->iter_values[i + offset]); + std::tie(rhs_expr, rhs_iter) = detect(rhs->iter_values[i]); + CHECK((lhs_iter && rhs_iter) || (!lhs_iter && !rhs_iter)) << "Incompatible binding"; + if (lhs_iter) VisitExpr(lhs_iter.value(), rhs_iter.value()); + if (is_zero(rhs_expr)) { + CHECK(is_zero(lhs_expr)) << "Incompatible binding"; + } else { + const auto* bv = rhs_expr.as(); + if (!bv) { + VisitExpr(lhs_expr, rhs_expr); + } else { + auto it = equal_map_.find(GetRef(bv)); + if (it == equal_map_.end()) { + equal_map_[GetRef(bv)] = lhs_expr; + } else { + CHECK(it->second->IsInstance()); + VisitExpr(lhs_expr, Downcast(it->second)); + } + } + } + } + } else { + for (size_t i = 0; i < rhs->iter_values.size(); ++i) { + if (!VisitExpr(op->iter_values[i + offset], rhs->iter_values[i])) return false; + } + const Block& block = op->block; + for (size_t i = 0; i < offset; ++i) { + Var block_var = Downcast(op->iter_values[i]); + auto it = equal_map_.find(block_var); + equal_map_[block->iter_vars[i]->var] = (it == equal_map_.end() ? block_var : it->second); + } + } + } + + return VisitExpr(op->predicate, rhs->predicate) && VisitStmt(op->block, rhs->block); +} + +bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) { + const auto* rhs = other.as(); + // Check block equal + // All iter var and buffer region should matches including the order + + // Check iterVar + // need to use DefEqual to remap vars + // Note: + // We only compare the inner most several axis + if (op->iter_vars.size() < rhs->iter_vars.size()) return false; + + size_t offset = op->iter_vars.size() - rhs->iter_vars.size(); + for (size_t i = 0; i < rhs->iter_vars.size(); ++i) { + auto lhs_var = op->iter_vars[i + offset], rhs_var = rhs->iter_vars[i]; + // Skip iter dom + if (!DefEqual(lhs_var->var, rhs_var->var)) return false; + if (lhs_var->iter_type != rhs_var->iter_type) return false; + } + + for (size_t i = 0; i < offset; ++i) { + if (is_scope_block) { + extra_block_vars_.push_back(op->iter_vars[i]); + } + } + + if (!is_scope_block) { + if (!CompareArray(op->writes, rhs->writes, &TensorizeComparator::CompareBufferRegion)) { + return false; + } + if (!CompareArray(op->reads, rhs->reads, &TensorizeComparator::CompareBufferRegion)) { + return false; + } + if (!CompareAnnotationMap(op->annotations, rhs->annotations)) { + return false; + } + if (!CompareArray(op->alloc_buffers, rhs->alloc_buffers, &TensorizeComparator::CompareBuffer)) { + return false; + } + } + if (!is_scope_block) is_inner_block = false; + is_scope_block = false; + return VisitStmt(op->body, rhs->body); +} + +// Exprs +#define TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OpName) \ + bool TensorizeComparator::VisitExpr_(const OpName* op, const PrimExpr& other) { \ + const auto* rhs = other.as(); \ + return VisitExpr(op->a, rhs->a) && VisitExpr(op->b, rhs->b); \ + } + +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AddNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(SubNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MulNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(DivNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(ModNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(EQNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(NENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LTNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(LENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GTNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(GENode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(AndNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(OrNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MinNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(MaxNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorDivNode); +TVM_DECLARE_TENSORIZE_COMPARATOR_BINOP(FloorModNode); + +bool TensorizeComparator::VisitExpr_(const IntImmNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareType(op->dtype, rhs->dtype) && op->value == rhs->value; +} + +bool TensorizeComparator::VisitExpr_(const FloatImmNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareType(op->dtype, rhs->dtype) && op->value == rhs->value; +} + +bool TensorizeComparator::VisitExpr_(const CastNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareType(op->dtype, rhs->dtype) && VisitExpr(op->value, rhs->value); +} + +bool TensorizeComparator::VisitExpr_(const VarNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + auto lhs = GetRef(op); + if (lhs.same_as(other)) return true; + if (!CompareType(op->dtype, rhs->dtype)) return false; + auto it = equal_map_.find(lhs); + return it != equal_map_.end() && it->second.same_as(other); +} + +bool TensorizeComparator::VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + return CompareBufferAccess(op, rhs); +} + +bool TensorizeComparator::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { + if (lhs.same_as(rhs)) return true; + if (lhs->type_index() != rhs->type_index()) return false; + auto it = equal_map_.find(lhs); + // If there is already a mapping + if (it != equal_map_.end()) return it->second.same_as(rhs); + equal_map_[lhs] = rhs; + return true; +} + +bool TensorizeComparator::CompareAnnotation(const std::pair& lhs, + const std::pair& rhs) { + if (lhs.first != rhs.first) return false; + if (!lhs.second.same_as(rhs.second)) return false; + return VisitExpr(Downcast(lhs.second), Downcast(rhs.second)); +} + +bool TensorizeComparator::CompareAnnotationMap(const Map& lhs, + const Map& rhs) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + + auto sort_map = + [](const Map& map) -> std::vector> { + std::vector> ret; + ret.reserve(map.size()); + for (const auto& pair : map) { + ret.emplace_back(pair); + } + sort(ret.begin(), ret.end()); + return ret; + }; + + auto lhs_array = sort_map(lhs), rhs_array = sort_map(rhs); + + for (size_t i = 0; i < lhs.size(); ++i) { + if (!CompareAnnotation(lhs_array[i], rhs_array[i])) return false; + } + return true; +} + +bool TensorizeComparator::CompareBuffer(const Buffer& lhs, const Buffer& rhs) { + if (lhs.same_as(rhs)) return true; + // Remap both buffer itself and buffer data + // Skip buffer shape + bool equal = DefEqual(lhs, rhs) && DefEqual(lhs->data, rhs->data) && + CompareType(lhs->dtype, rhs->dtype) && lhs.scope() == rhs.scope(); + if (equal) { + rhs_buffer_map_[rhs] = lhs; + } else if (assert_mode_) { + LOG(FATAL) << "Buffers are not matching between:" << lhs << " and " << rhs; + } + return equal; +} + +bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs) { + // Only for block region declaration + if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; + // Number of indices in desc_block must be smaller than it in AST + if (rhs->region.size() > lhs->region.size()) return false; + + std::vector lhs_region; + for (const auto& range : lhs->region) { + lhs_region.push_back(Range::FromMinExtent(range->min, range->extent)); + } + // special judge size 1 buffer + if (rhs->region.size() == 1 && is_zero(rhs->region[0]->min) && is_one(rhs->region[0]->extent)) { + lhs_region.push_back(Range::FromMinExtent(0, 1)); + } + size_t offset = lhs_region.size() - rhs->region.size(); + // initialize buffer indices + bool need_update = false; + if (!buffer_indices_.count(lhs->buffer)) { + need_update = true; + buffer_indices_[lhs->buffer] = std::vector(); + } else { + if (offset != buffer_indices_[lhs->buffer].size()) return false; + } + std::vector& indices = buffer_indices_[lhs->buffer]; + for (size_t i = 0; i < offset; ++i) { + const Range& range = lhs_region[i]; + // High-dim region must be element-wise + if (!is_one(range->extent)) return false; + if (need_update) { + indices.push_back(range->min); + } else { + // The order matters since we only map inner block_var to outside block_var + if (!VisitExpr(range->min, indices[i])) return false; + } + } + for (size_t i = 0; i < rhs->region.size(); ++i) { + if (!CompareRange(lhs_region[i + offset], rhs->region[i])) return false; + } + return true; +} + +// Only for BufferStoreNode and BufferLoadNode +template +bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { + if (!CompareBuffer(lhs->buffer, rhs->buffer)) return false; + + if (rhs->indices.size() > lhs->indices.size()) return false; + // special judge size 1 buffer + if (rhs->indices.size() == 1 && is_zero(rhs->indices[0])) return true; + // otherwise + size_t offset = lhs->indices.size() - rhs->indices.size(); + for (size_t i = 0; i < rhs->indices.size(); ++i) { + if (!VisitExpr(lhs->indices[i + offset], rhs->indices[i])) return false; + } + return true; +} + +template +bool TensorizeComparator::CompareArray(const Array& lhs, const Array& rhs, F cmp) { + if (lhs.same_as(rhs)) return true; + if (lhs.size() != rhs.size()) return false; + for (size_t i = 0; i < lhs.size(); ++i) { + if (!(this->*cmp)(lhs[i], rhs[i])) return false; + } + return true; +} + +bool TensorizeComparator::CompareRange(const Range& lhs, const Range& rhs) { + return VisitExpr(lhs->min, rhs->min) && VisitExpr(lhs->extent, rhs->extent); +} + +bool TensorizeComparator::CompareType(const DataType& lhs, const DataType& rhs) { + if (lhs == rhs) return true; + return lhs.code() == rhs.code() && lhs.bits() == rhs.bits() && lhs.lanes() == rhs.lanes(); +} + +// Deep comparison to check if two IR graph are equivalent +bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) { + bool equal = (n->type_index() == other->type_index()) && ExprComparator::VisitExpr(n, other); + if (!equal && assert_mode_) + LOG(FATAL) << "Exprs are not matching between:" << n << " and " << other; + return equal; +} + +Array> TrivialSubspaceDivision(const Array& iter_vars, + const Array& bindings, + const std::vector& outer_loops, + const std::vector& inner_loops, + const PrimExpr& predicate) { + if (!is_one(predicate)) return {}; + std::vector> res; + std::unordered_set outer_loop_vars; + std::unordered_set inner_loop_vars; + for (const Var& var : outer_loops) { + outer_loop_vars.insert(var.get()); + } + for (const Var& var : inner_loops) { + inner_loop_vars.insert(var.get()); + } + for (size_t i = 0; i < bindings.size(); ++i) { + bool outer = UsesVar( + bindings[i], [&outer_loop_vars](const VarNode* var) { return outer_loop_vars.count(var); }); + bool inner = UsesVar( + bindings[i], [&inner_loop_vars](const VarNode* var) { return inner_loop_vars.count(var); }); + bool is_var = bindings[i]->IsInstance(); + if (outer && !inner) { + arith::IterMark outer{nullptr}; + if (is_var) { + outer = arith::IterMark( + arith::IterSplitExpr(arith::IterMark(bindings[i], iter_vars[i]->dom->extent)), + iter_vars[i]->dom->extent); + } else { + outer = arith::IterMark(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent); + } + arith::IterMark inner(arith::IterSumExpr({}, 0), 1); + res.push_back(Array({outer, inner})); + } else if (inner && !outer) { + arith::IterMark inner{nullptr}; + if (is_var) { + inner = arith::IterMark( + arith::IterSplitExpr(arith::IterMark(bindings[i], iter_vars[i]->dom->extent)), + iter_vars[i]->dom->extent); + } else { + inner = arith::IterMark(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent); + } + arith::IterMark outer(arith::IterSumExpr({}, 0), 1); + res.push_back(Array({outer, inner})); + } else if (!outer && !inner) { + arith::IterMark outer(arith::IterSumExpr({}, 0), 1); + arith::IterMark inner(arith::IterSumExpr({}, 0), 1); + res.push_back(Array({outer, inner})); + } else { + return {}; + } + } + res.push_back({arith::IterMark(arith::IterSumExpr({}, 0), Bool(true)), + arith::IterMark(arith::IterSumExpr({}, 0), Bool(true))}); + return res; +} + +StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { + /*! + * Check: + * - The sub AST is one-line with only one block + * + * Mutate: + * - extra block var from the only block + * - Update block binding + */ + const auto* loop = loop_sref->StmtAs(); + CHECK(loop) << "TypeError: Only support blockize a loop for now, but get type: " + << loop_sref->stmt->GetTypeKey(); + // check there exists no SeqStmt under loop + CHECK(CheckOneLine(GetRef(loop))) << "ValueError: Only one line subtree can be blockize"; + // get the inner Block, BlockRealize and StmtSRef + Array child_blocks = GetChildBlocks(self, loop_sref); + CHECK_EQ(child_blocks.size(), 1) << "ValueError: Only one line subtree can be blockize"; + StmtSRef block_sref = child_blocks[0]; + BlockRealize block_realize = GetBlockRealize(self, block_sref); + Block block = block_realize->block; + // collect loops inside/outside loop_sref + std::vector outer_loops, inner_loops; + std::vector outer_iters, inner_iters; + std::unordered_map iters; + bool inner = true; + for (StmtSRef current_sref = block_sref;;) { + current_sref = GetRef(current_sref->parent); + if (!current_sref.defined()) break; + const auto* current_loop = current_sref->StmtAs(); + if (!current_loop) break; + if (inner) { + inner_loops.push_back(current_loop); + inner_iters.push_back(current_loop->loop_var); + } else { + outer_loops.push_back(current_loop); + outer_iters.push_back(current_loop->loop_var); + } + iters[current_loop->loop_var] = Range::FromMinExtent(current_loop->min, current_loop->extent); + if (current_sref == loop_sref) inner = false; + } + arith::Analyzer analyzer; + DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); + Array> division = + arith::SubspaceDivide(block_realize->iter_values, iters, inner_iters, + block_realize->predicate, false, &analyzer, diag_ctx); + if (division.empty()) { + // It is possible to blockize if we can not do perfect subspace division if we can divide + // the block var bindings into two categories + // 1. The binding covers no inner loop var + // 2. The binding covers only inner loop vars + division = TrivialSubspaceDivision(block->iter_vars, block_realize->iter_values, outer_iters, + inner_iters, block_realize->predicate); + } + CHECK(!division.empty()) << "ValueError: The bindings of the block below can not be blockized"; + // Generate a new inner block + Array inner_block_vars, outer_block_vars; + Array inner_bindings, outer_bindings; + std::unordered_map block_var_no; + std::unordered_map bv_iters; + for (size_t i = 0; i < block->iter_vars.size(); ++i) { + const IterVar& iter_var = block->iter_vars[i]; + const arith::IterMapExprNode* outer_binding = + division[i][0]->source.as(); + const arith::IterMapExprNode* inner_binding = + division[i][1]->source.as(); + ICHECK(outer_binding); + ICHECK(inner_binding); + if (is_one(division[i][1]->extent)) { // IsOuter + // extract this iter var to outer block directly + outer_bindings.push_back( + arith::NormalizeIterMapToExpr(GetRef(outer_binding))); + outer_block_vars.push_back(iter_var); + // bv_iters[iter_var->var] = Range::FromMinExtent(0, division[i][0]->extent); + } else { + const IterVar outer_var(Range::FromMinExtent(0, division[i][0]->extent), + iter_var->var.copy_with_suffix("o"), iter_var->iter_type); + outer_bindings.push_back( + arith::NormalizeIterMapToExpr(GetRef(outer_binding))); + outer_block_vars.push_back(outer_var); + // generate a new iter var for outer block + PrimExpr base = is_one(division[i][0]->extent) ? 0 : outer_var * division[i][1]->extent; + if (const auto* op = division[i][1]->source.as()) { + base = base + op->base; + inner_bindings.push_back(base + + arith::NormalizeIterMapToExpr(arith::IterSumExpr(op->args, 0))); + } else { + inner_bindings.push_back( + base + arith::NormalizeIterMapToExpr(GetRef(inner_binding))); + } + inner_block_vars.push_back(iter_var); + bv_iters[iter_var->var] = Range::FromMinExtent(base, division[i][1]->extent); + } + block_var_no[iter_var->var] = i; + } + Block inner_block = block; + inner_block.CopyOnWrite()->iter_vars = inner_block_vars; + inner_block.CopyOnWrite()->init = NullOpt; + BlockRealize inner_br = block_realize; + inner_br.CopyOnWrite()->iter_values = inner_bindings; + inner_br.CopyOnWrite()->predicate = division.back()[1]->extent; + inner_br.CopyOnWrite()->block = inner_block; + // Regenerate inner_loops + Stmt body = inner_br; + for (const auto& inner_loop : inner_loops) { + auto loop_node = make_object(*inner_loop); + loop_node->body = body; + body = For(loop_node); + } + // Regenerate init for outer block + Optional new_init = NullOpt; + if (block->init.defined()) { + std::vector init_loops; + std::vector init_block_vars; + std::vector init_block_vars_copy; + std::vector init_bindings; + std::unordered_map binding_replace_map; + std::unordered_map bv_replace_map; + std::unordered_map new_block_vars2old_index; + for (size_t i = 0; i < inner_block_vars.size(); ++i) { + if (inner_block_vars[i]->iter_type == IterVarType::kDataPar && + UsesVar(block->init.value(), + [v = inner_block_vars[i]->var](const VarNode* var) { return var == v.get(); })) { + // copy init block vars and ignore reduce block vars + init_block_vars.push_back(i); + IterVar init_block_var = inner_block_vars[i]; + init_block_var.CopyOnWrite()->var = inner_block_vars[i]->var.copy_with_suffix("_init"); + init_block_vars_copy.push_back(init_block_var); + bv_replace_map[inner_block_vars[i]->var] = init_block_var->var; + new_block_vars2old_index[init_block_var.get()] = i; + } + } + for (const ForNode* inner_loop : inner_loops) { + for (size_t i = 0; i < init_block_vars.size(); ++i) { + if (UsesVar(inner_bindings[new_block_vars2old_index[init_block_vars_copy[i].get()]], + [v = inner_loop->loop_var](const VarNode* var) { return var == v.get(); })) { + // copy loops related to init block vars + For init_loop = GetRef(inner_loop); + init_loop.CopyOnWrite()->loop_var = inner_loop->loop_var.copy_with_suffix(""); + // replace loop vars with copied loop vars + binding_replace_map[inner_loop->loop_var] = init_loop->loop_var; + init_loops.push_back(init_loop); + break; + } + } + } + for (size_t i = 0; i < init_block_vars.size(); ++i) { + init_bindings.push_back(Substitute(inner_bindings[init_block_vars[i]], binding_replace_map)); + } + new_init = Substitute(Block(/*iter_vars=*/init_block_vars_copy, // + /*reads=*/{}, // + /*writes=*/block->writes, // + /*name_hint=*/block->name_hint + "_init", // + /*body=*/block->init.value(), // + /*init=*/NullOpt), + bv_replace_map); + new_init = + BlockRealize(init_bindings, division.back()[1]->extent, Downcast(new_init.value())); + for (const auto& init_loop : init_loops) { + For new_init_loop = init_loop; + new_init_loop.CopyOnWrite()->body = new_init.value(); + new_init = new_init_loop; + } + } + // Calculate outer block's IO region + auto rewrite_range = [&](const Range& range) -> Range { + const Array& res = + arith::DetectIterMap({range->min}, bv_iters, true, false, &analyzer, diag_ctx); + ICHECK_EQ(res.size(), 1); + const arith::IterSumExpr& normalized_expr = res[0]; + PrimExpr extent = 1; + if (normalized_expr->args.size() == 1) { + CHECK(analyzer.CanProve(normalized_expr->args[0]->scale - range->extent == 0)); + extent = normalized_expr->args[0]->extent; + } + return Range::FromMinExtent(normalized_expr->base, extent * range->extent); + }; + std::vector reads, writes; + auto rewrite_region = [&](std::vector* regions, Array old_regions) { + for (auto buffer_region : old_regions) { + std::vector region; + for (const auto& range : buffer_region->region) { + region.push_back(rewrite_range(range)); + } + (*regions).emplace_back(buffer_region->buffer, region); + } + }; + rewrite_region(&reads, block->reads); + rewrite_region(&writes, block->writes); + // Generate a new outer block + auto outer_block = Block(/*iter_vars=*/outer_block_vars, // + /*reads=*/reads, // + /*writes=*/writes, // + /*name_hint=*/"blockized_" + block->name_hint, // + /*body=*/std::move(body), // + /*init=*/new_init); + auto outer_realize = BlockRealize(outer_bindings, division.back()[0]->extent, outer_block); + + self->Replace(loop_sref, outer_realize, {{block, inner_block}}); + { + StmtSRef block_sref = self->stmt2ref.at(outer_block.get()); + StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, + /*require_compact_dataflow*/ false); + UpdateScope(self, scope_sref); + } + RecalculateCachedFlags(self.operator->()); + + // } + // TODO(@wuwei): fix affine flags + // self->Replace(loop_sref, outer_realize, {{block, inner_block}}); + // { + // StmtSRef block_sref = self->stmt2ref.at(inner_block.get()); + // UpdateAffineFlag(self, block_sref); + // } + // { + // StmtSRef block_sref = self->stmt2ref.at(outer_block.get()); + // StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, + // /*require_compact_dataflow*/false); + // UpdateScope(self, scope_sref); + // UpdateAffineFlag(self, scope_sref); + // } + // { + // StmtSRef block_sref = self->stmt2ref.at(outer_block.get()); + // UpdateScope(self, block_sref); + // UpdateAffineFlag(self, block_sref); + // } + + // // Check loop binding + + // { + // struct BindingValidator : public StmtVisitor { + // void VisitStmt_(const BlockRealizeNode* realize) final { + // StmtSRef& sref = self->stmt2ref.at(realize->block.get()); + // UpdateAffineFlag(self, sref); + // VisitStmt(realize->block->body); + // } + // ScheduleState self; + // }; + // BindingValidator validator; + // validator.self = self; + // const PrimFuncNode* func = GetRootPrimFunc(self->mod, GetRootBlock(loop_sref).get(), + // nullptr); validator(func->body); + // } + return self->stmt2ref.at(outer_block.get()); +} + +// Stmts + +void BufferRemap(const TensorIntrin& intrinsic, + std::unordered_map* buffer_map) { + ICHECK_EQ(intrinsic->description->params.size(), intrinsic->implementation->params.size()); + for (size_t i = 0; i < intrinsic->description->params.size(); ++i) { + const auto& lhs_var = intrinsic->description->params[i]; + const auto& lhs_buffer = intrinsic->description->buffer_map[lhs_var]; + const auto& rhs_var = intrinsic->implementation->params[i]; + const auto& rhs_buffer = intrinsic->implementation->buffer_map[rhs_var]; + (*buffer_map)[rhs_buffer] = lhs_buffer; + } +} + +// Replace buffer with its data, element_offset +class BufferReplacer : public StmtExprMutator { + public: + explicit BufferReplacer( + const std::unordered_map& buffer_map, + const std::unordered_map& var_map, + std::vector&& extra_block_vars, + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& + buffer_indices) + : buffer_map_(buffer_map), + var_map_(var_map), + extra_block_vars_(std::move(extra_block_vars)), + buffer_indices_(buffer_indices) {} + + Stmt VisitStmt_(const BufferStoreNode* op) final { + auto s = StmtExprMutator::VisitStmt_(op); + op = s.as(); + CHECK(op); + auto it = buffer_map_.find(op->buffer); + if (it != buffer_map_.end()) { + auto n = CopyOnWrite(op); + n->buffer = it->second; + auto it2 = buffer_indices_.find(n->buffer); + CHECK(it2 != buffer_indices_.end()); + n->indices.insert(n->indices.begin(), it2->second.begin(), it2->second.end()); + return Stmt(n); + } else { + return GetRef(op); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + auto s = StmtExprMutator::VisitExpr_(op); + op = s.as(); + CHECK(op); + auto it = buffer_map_.find(op->buffer); + if (it != buffer_map_.end()) { + auto n = make_object(*op); + n->buffer = it->second; + auto it2 = buffer_indices_.find(n->buffer); + CHECK(it2 != buffer_indices_.end()); + n->indices.insert(n->indices.begin(), it2->second.begin(), it2->second.end()); + return PrimExpr(n); + } else { + return GetRef(op); + } + } + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = var_map_.find(op); + if (it != var_map_.end()) { + return GetRef(it->second); + } else { + auto it2 = block_var_map_.find(op); + if (it2 != block_var_map_.find(op)) { + return GetRef(it2->second); + } else { + return GetRef(op); + } + } + } + + Stmt VisitStmt_(const BlockNode* op) final { + std::vector extra_block_var; + std::unordered_map block_var_map; + for (const auto& iter_var : extra_block_vars_) { + auto n = runtime::make_object(*(iter_var.get())); + IterVar block_var(n); + extra_block_var.push_back(block_var); + block_var_map[iter_var->var.get()] = block_var->var.get(); + } + std::swap(block_var_map, block_var_map_); + auto s = StmtExprMutator::VisitStmt_(op); + op = s.as(); + CHECK(op); + + auto iter_vars = op->iter_vars; + iter_vars.insert(iter_vars.begin(), extra_block_var.begin(), extra_block_var.end()); + auto reads = UpdateBufferViaMap(op->reads); + auto writes = UpdateBufferViaMap(op->writes); + + std::swap(block_var_map, block_var_map_); + + if (reads.same_as(op->reads) && writes.same_as(op->writes) && + iter_vars.same_as(op->iter_vars)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->reads = std::move(reads); + n->writes = std::move(writes); + n->iter_vars = std::move(iter_vars); + return Block(n); + } + } + + private: + const std::unordered_map& buffer_map_; + const std::unordered_map& var_map_; + std::unordered_map block_var_map_; + const std::vector& extra_block_vars_; + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& + buffer_indices_; + + Array UpdateBufferViaMap(const Array& buffer_regions) { + auto f_mutate = [this](const BufferRegion& buffer_region) { + auto it = buffer_map_.find(buffer_region->buffer); + if (it != buffer_map_.end()) { + auto n = make_object(*buffer_region.get()); + n->buffer = it->second; + auto it2 = buffer_indices_.find(n->buffer); + if (it2 != buffer_indices_.end()) { + Region region; + for (const auto& min : it2->second) { + region.push_back(Range::FromMinExtent(VisitExpr(min), 1)); + } + n->region.insert(n->region.begin(), region.begin(), region.end()); + } + while (n->region.size() > n->buffer->shape.size()) { + const Range& range = n->region.back(); + ICHECK(is_one(range->extent) && is_zero(range->min)); + n->region.pop_back(); + } + return BufferRegion(n); + } else { + return buffer_region; + } + }; + return MutateArray(buffer_regions, f_mutate); + } +}; + +void Tensorize(ScheduleState self, const StmtSRef& loop_sref, const TensorIntrin& intrinsic) { + /*! + * Check: + * - Check buffer binding, including type, alignment, shape and etc. + * - Check the sub AST is equal to the description function. + * + * Mutate: + * - Blockize the sub AST (please refer blockize for details) + * - Bind buffers + * - Mutate implement function with buffer binding + * - Replace the sub tree with the mutated function. + */ + const auto* loop = loop_sref->StmtAs(); + CHECK(loop) << "Only support tensorize a loop for now"; + + const auto* desc_block_realize = + Downcast(intrinsic->description->body)->block->body.as(); + const Block& desc_block = desc_block_realize->block; + const auto* impl_block_realize = + Downcast(intrinsic->implementation->body)->block->body.as(); + Block impl_block = impl_block_realize->block; + + const StmtSRef& block_sref = Blockize(self, loop_sref); + const BlockRealize& block_realize = GetBlockRealize(self, block_sref); + + TensorizeComparator comparator; + bool equal = comparator.VisitStmt(block_realize, GetRef(desc_block_realize)); + CHECK(equal) << "The AST subtree does not match intrinsic description"; + // Map from intrinsic func buffer to description func buffer + std::unordered_map intrin_buffer_map; + BufferRemap(intrinsic, &intrin_buffer_map); + // Map form intrinsic func buffer to current AST buffer + std::unordered_map buffer_map; + for (const auto& pair : intrin_buffer_map) { + auto it = comparator.rhs_buffer_map_.find(pair.second); + CHECK(it != comparator.rhs_buffer_map_.end()); + buffer_map[pair.first] = it->second; + } + // Build Var map, which is the map from intrin buffer data to AST buffer data + std::unordered_map var_map; + auto update_var_map = [&var_map](const PrimExpr& lhs, const PrimExpr& rhs) { + if (const auto* var = lhs.as()) { + var_map[var] = rhs.get(); + } + }; + for (const auto& pair : buffer_map) { + update_var_map(pair.first->data, pair.second->data); + } + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_region_map; + for (const auto& read : impl_block->reads) { + buffer_region_map.emplace(read->buffer, read->region); + } + for (const auto& write : impl_block->writes) { + buffer_region_map.emplace(write->buffer, write->region); + } + + Array match_buffer_regions; + for (size_t i = 0; i < intrinsic->implementation->params.size(); ++i) { + const auto& param = intrinsic->implementation->params[i]; + const auto& buffer = intrinsic->implementation->buffer_map.at(param); + const auto& source = buffer_map.at(buffer); + Region region = buffer_region_map.at(buffer); + auto extra_indices = comparator.buffer_indices_.at(source); + std::vector extra_buffer_ranges; + std::transform(extra_indices.begin(), extra_indices.end(), + std::back_inserter(extra_buffer_ranges), + [](const PrimExpr& index) { return Range::FromMinExtent(index, 1); }); + region.insert(region.begin(), extra_buffer_ranges.begin(), extra_buffer_ranges.end()); + match_buffer_regions.push_back(MatchBufferRegion(buffer, BufferRegion(source, region))); + } + + impl_block.CopyOnWrite()->match_buffers = match_buffer_regions; + std::unordered_map bv_map; + for (size_t i = 0; i < desc_block->iter_vars.size(); ++i) { + auto it = comparator.equal_map_.find(desc_block->iter_vars[i]->var); + if (it != comparator.equal_map_.end()) { + bv_map[impl_block->iter_vars[i]->var] = Downcast(it->second); + } else { + bv_map[impl_block->iter_vars[i]->var] = 0; + } + } + Stmt new_body = SubstituteInScope(impl_block, [&](const VarNode* var) -> PrimExpr { + auto it = bv_map.find(GetRef(var)); + if (it == bv_map.end()) + return GetRef(var); + else + return it->second; + }); + // Replace + ObjectPtr new_block_ptr = make_object(*block_realize->block.get()); + new_block_ptr->body = Downcast(new_body)->body; + ICHECK(new_block_ptr->match_buffers.empty()); + new_block_ptr->match_buffers = Downcast(new_body)->match_buffers; + Block new_block(new_block_ptr); + self->Replace(self->stmt2ref.at(block_realize->block.get()), new_block, + {{block_realize->block, new_block}}); + RecalculateCachedFlags(self.operator->()); + // { + // struct BindingValidator : public StmtVisitor { + // void VisitStmt_(const BlockRealizeNode* realize) final { + // StmtSRef& sref = self->stmt2ref.at(realize->block.get()); + // UpdateAffineFlag(self, sref); + // VisitStmt(realize->block->body); + // } + // ScheduleState self; + // }; + // BindingValidator validator; + // StmtSRef block_sref = self->stmt2ref.at(new_block.get()); + // const PrimFuncNode* func = GetRootPrimFunc(self->mod, GetRootBlock(block_sref).get(), + // nullptr); validator.self = self; validator(func->body); + // } +} + +struct BlockizeTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Blockize"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { + return sch->Blockize(loop_rv); + } + + static String UnpackedAsPython(Array outputs, String loop_rv) { + PythonAPICall py("blockize"); + py.Input("loop", loop_rv); + py.SingleOutput(outputs); + return py.Str(); + } + + friend struct UnpackedInstTraits; +}; + +struct TensorizeTraits : public UnpackedInstTraits { + static constexpr const char* kName = "Tensorize"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 1; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, String intrin_name) { + return sch->Tensorize(loop_rv, intrin_name); + } + + static String UnpackedAsPython(Array outputs, String loop_rv, String intrin_name) { + PythonAPICall py("tensorize"); + py.Input("loop", loop_rv); + py.Input("intrin", intrin_name); + return py.Str(); + } + + friend struct UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(BlockizeTraits); +TVM_REGISTER_INST_KIND_TRAITS(TensorizeTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 00886e8f8a22..8d3c151c6868 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -25,6 +25,49 @@ using support::NDIntSet; /******** Error Classes ********/ +/*! + * \brief Represent the iteration domain to fully cover the required region of Intersect(dom, bound) + * The bound region may not get directly intersected with dom region, instead we try to generate + * extra predicates for non-trivial bound. The domain info class can also union with each other. + */ +struct BlockVarDomainInfo { + arith::IntSet dom{arith::IntSet::Nothing()}; // dom is ensured to be bounded + arith::IntSet bound{arith::IntSet::Nothing()}; + + /*! \brief Relaxed union operation */ + void Union(const BlockVarDomainInfo& other) { + // just relax (d0 ^ b0) v (d1 ^ b1) to (d0 v d1) ^ (b0 v b1) + dom = arith::Union({dom, other.dom}); + bound = arith::Union({bound, other.bound}); + } + + /*! \brief Simplify domain info */ + void Simplify(arith::Analyzer* analyzer) { + auto to_simplified = [analyzer](const arith::IntSet& set) { + PrimExpr min = set.HasLowerBound() ? analyzer->Simplify(set.min()) : set.min(); + PrimExpr max = set.HasUpperBound() ? analyzer->Simplify(set.max()) : set.max(); + return arith::IntSet::Interval(min, max); + }; + // if no dom specified, try use bound as dom + if (dom.IsNothing()) { + if (bound.HasLowerBound() && bound.HasUpperBound()) { + bound = to_simplified(bound); + std::swap(dom, bound); + } + return; + } + // simplify intsets + dom = to_simplified(dom); + bound = to_simplified(bound); + // if can proof the dom is within bound, remove bound + auto intersect = to_simplified(arith::Intersect({dom, bound})); + if (analyzer->CanProveEqual(dom.min(), intersect.min()) && + analyzer->CanProveEqual(dom.max(), intersect.max())) { + bound = arith::IntSet::Nothing(); + } + } +}; + /*! * \brief An error raised when not all required blocks are under the given loop. * \tparam is_consumer Indicates if all the required blocks are consumers or producers @@ -181,7 +224,8 @@ class ScopeReconstructor : private StmtMutator { * \param iter_doms The domain of each block var * \param preserve_unit_loops Whether to generate unit loops where the loop extent is 1 */ - void MakeNewLoop(int insert_position, std::vector iter_doms, bool preserve_unit_loops) { + void MakeNewLoop(int insert_position, const std::vector& iter_doms, + arith::Analyzer* analyzer, bool preserve_unit_loops, ScheduleState self) { int n_iters = iter_doms.size(); Array loop_vars; Array loop_extents; @@ -189,19 +233,37 @@ class ScopeReconstructor : private StmtMutator { loop_vars.reserve(n_iters); loop_extents.reserve(n_iters); iter_values.reserve(n_iters); + PrimExpr predicate = const_true(); for (int i = 0; i < n_iters; ++i) { - const Range& iter_dom = iter_doms[i]; - if (preserve_unit_loops || !is_one(iter_dom->extent)) { + arith::IntSet pred_bound = iter_doms[i].bound; + arith::IntSet iter_dom_intset = iter_doms[i].dom; + Range iter_dom = iter_dom_intset.CoverRange(block_->iter_vars[i]->dom); + if (preserve_unit_loops || !is_one(iter_dom->extent) || !pred_bound.IsNothing()) { Var var("ax" + std::to_string(loop_vars.size()), DataType::Int(32)); loop_vars.push_back(var); loop_extents.push_back(iter_dom->extent); iter_values.push_back(iter_dom->min + var); + if (is_one(iter_dom->extent)) { + analyzer->Bind(var, 0); + } else { + analyzer->Bind(var, Range::FromMinExtent(0, iter_dom->extent)); + } + if (pred_bound.HasLowerBound()) { + PrimExpr lower_bound = block_->iter_vars[i]->var >= pred_bound.min(); + predicate = predicate && lower_bound; + } + if (pred_bound.HasUpperBound()) { + PrimExpr upper_bound = block_->iter_vars[i]->var < pred_bound.max() + 1; + predicate = predicate && upper_bound; + } } else { iter_values.push_back(iter_dom->min); } } - this->new_block_realize_ = - BlockRealize(std::move(iter_values), const_true(), std::move(block_)); + if (!analyzer->CanProve(predicate)) { + this->predicate = predicate; + } + this->new_block_realize_ = BlockRealize(std::move(iter_values), Bool(true), std::move(block_)); Stmt new_subtree = this->new_block_realize_; for (int i = static_cast(loop_vars.size()) - 1; i >= 0; --i) { const Var& loop_var = loop_vars[i]; @@ -255,6 +317,8 @@ class ScopeReconstructor : private StmtMutator { Stmt rm_src_stmt_{nullptr}; /*! \brief The plan to remove the given block by replacing to this loop/block in the AST */ Stmt rm_tgt_stmt_{nullptr}; + /*! \brief Bound predicate for the given block to be moved */ + Optional predicate{NullOpt}; }; /*! @@ -310,17 +374,18 @@ void RelaxBufferRegions(const Map& binding, * domain * \param provided The provided integer set to cover the required domain * \param required The required domain to be covered + * \param required_bound The additional region bound of the required domain to be covered * \param iter_doms The result iteration domains to be updated * \param analyzer The arithmetic analyzer */ -void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& required, - std::unordered_map>* iter_doms, - arith::Analyzer* analyzer) { +std::pair SolveBlockVarDomain(const arith::IntSet& provided, + const arith::IntSet& required, + arith::Analyzer* analyzer) { PrimExpr provided_min = analyzer->Simplify(provided.min()); - PrimExpr provided_extent = analyzer->Simplify(provided.max() - provided_min + 1); + PrimExpr provided_max = analyzer->Simplify(provided.max()); PrimExpr required_min = analyzer->Simplify(required.min()); - PrimExpr required_extent = analyzer->Simplify(required.max() - required_min + 1); - PrimExpr dom_min{nullptr}, dom_extent{nullptr}; + PrimExpr required_max = analyzer->Simplify(required.max()); + PrimExpr dom_min{nullptr}, dom_max{nullptr}; Var dom_var{ObjectPtr{nullptr}}; arith::PVar p_v; arith::PVar p_e; @@ -328,21 +393,58 @@ void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& re PrimExpr e = p_e.Eval(); dom_var = p_v.Eval(); dom_min = floordiv(required_min, e); - dom_extent = analyzer->Simplify((required_extent + e - 1) / e); - } else if (analyzer->CanProveEqual(provided_extent, 1) && p_v.Match(provided_min)) { - dom_var = p_v.Eval(); - dom_min = required_min; - dom_extent = required_extent; - } else { - ICHECK(false) << "ValueError: BufferRegion pattern match failed"; + dom_max = floordiv(required_max, e); + } else if (analyzer->CanProveEqual(provided_min, provided_max)) { + if (p_v.Match(provided_min)) { + dom_var = p_v.Eval(); + dom_min = required_min; + dom_max = required_max; + } else { + arith::PVar p_f; + if ((floordiv(p_v, p_f)).Match(provided_min)) { + // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1) + PrimExpr fac = p_f.Eval(); + if (analyzer->CanProveGreaterEqual(fac, 1)) { + dom_var = p_v.Eval(); + dom_min = required_min * fac; + dom_max = analyzer->Simplify(required_max * fac + fac - 1); + } + } else if ((floormod(p_v, p_f).Match(provided_min))) { + // generally domain of (x % fac) enforce no constraints to domain of x + dom_var = p_v.Eval(); + return std::make_pair(dom_var, arith::IntSet::Nothing()); + } + } } - auto it = iter_doms->find(dom_var.get()); + ICHECK(dom_var.defined()) << "ValueError: BufferRegion pattern match failed: " << provided_min; + return std::make_pair(dom_var, arith::IntSet::Interval(dom_min, dom_max)); +} + +/*! + * \brief Calculate the iteration domain of a provided integer set to fully cover the required + * domain + * \param provided The provided integer set to cover the required domain + * \param required The required domain to be covered + * \param required_bound The additional region bound of the required domain to be covered + * \param iter_doms The result iteration domains to be updated + * \param analyzer The arithmetic analyzer + */ +void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& required, + const arith::IntSet& required_bound, + std::unordered_map* iter_doms, + arith::Analyzer* analyzer) { + auto var_with_dom = SolveBlockVarDomain(provided, required, analyzer); + auto var_with_bound = SolveBlockVarDomain(provided, required_bound, analyzer); + const Var& var = var_with_dom.first; + const auto& var_dom = var_with_dom.second; + const auto& var_bound = var_with_bound.second; + ICHECK(var.same_as(var_with_bound.first)); + auto it = iter_doms->find(var.get()); if (it != iter_doms->end()) { - std::vector& doms = it->second; - doms.push_back(arith::IntSet::FromMinExtent(dom_min, dom_extent)); + it->second.Union({var_dom, var_bound}); } else { - ICHECK(analyzer->CanProveEqual(provided_min, required_min)); - ICHECK(analyzer->CanProveEqual(provided_extent, required_extent)); + ICHECK(analyzer->CanProveEqual(provided.min(), required.min())); + ICHECK(analyzer->CanProveEqual(provided.max(), required.max())); } } @@ -352,19 +454,19 @@ void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& re * \param provided_regions The region provided by one iteration instance of the block vars * \param required_regions The region required to be covered * \param analyzer The arithmetic analyzer - * \return A list of iteration domain corresponding to the given list of block vars + * \return A list of iteration domain info corresponding to the given list of block vars */ -std::vector CalculateBlockVarDomain( +std::vector CalculateBlockVarDomain( const Array& iter_vars, std::unordered_map> provided_regions, std::unordered_map> required_regions, arith::Analyzer* analyzer) { int n_iters = iter_vars.size(); // Step 1. Construct the mapping from block var to their iteration domain (initialized to empty) - std::unordered_map> iter_doms; + std::unordered_map iter_doms; iter_doms.reserve(n_iters); for (const IterVar& iter_var : iter_vars) { - iter_doms[iter_var->var.get()] = {}; + iter_doms[iter_var->var.get()] = BlockVarDomainInfo(); } // Step 2. For each buffer, update the domain according to the provided and required regions for (const auto& kv : provided_regions) { @@ -384,23 +486,23 @@ std::vector CalculateBlockVarDomain( for (int i = 0; i < ndim; ++i) { arith::IntSet provided = provided_region[i]; arith::IntSet required = required_region[i]; - required = arith::Intersect( - {std::move(required), arith::IntSet::FromMinExtent(Integer(0), buffer->shape[i])}); - UpdateBlockVarDomain(provided, required, &iter_doms, analyzer); + arith::IntSet required_bound = arith::IntSet::FromMinExtent(Integer(0), buffer->shape[i]); + UpdateBlockVarDomain(provided, required, required_bound, &iter_doms, analyzer); } } // Union the iter var domains, put them in the same order of block vars, and return - std::vector result; + std::vector result; result.reserve(n_iters); for (const IterVar& iter_var : iter_vars) { - const std::vector& doms = iter_doms.at(iter_var->var.get()); - arith::IntSet dom = arith::IntSet::FromRange(iter_var->dom); - if (!doms.empty()) { - dom = arith::Intersect({std::move(dom), arith::Union(doms)}); + BlockVarDomainInfo& info = iter_doms.at(iter_var->var.get()); + if (info.bound.IsNothing()) { + info.bound = arith::IntSet::FromRange(iter_var->dom); + } else { + info.bound = arith::Intersect({info.bound, arith::IntSet::FromRange(iter_var->dom)}); } - PrimExpr min = analyzer->Simplify(dom.min()); - PrimExpr extent = analyzer->Simplify(dom.max() - min + 1); - result.push_back(Range::FromMinExtent(min, extent)); + info.Simplify(analyzer); + ICHECK(!info.dom.IsNothing()); + result.push_back(info); } return result; } @@ -450,9 +552,11 @@ void CalculateProvidedRequiredRegions( /******** Main Implementation ********/ template -void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref, - const StmtSRef& loop_sref, bool preserve_unit_loops, - arith::Analyzer* analyzer, bool check_only = false) { +std::function ComputeAtOrReverseComputeAtImpl(ScheduleState self, + const StmtSRef& block_sref, + const StmtSRef& loop_sref, + bool preserve_unit_loops, + arith::Analyzer* analyzer) { const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); // Step 1. Bunch of checks @@ -497,41 +601,47 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s /*consumer_srefs=*/std::move(consumer_srefs), /*provided_regions=*/&provided_regions, /*required_regions=*/&required_regions); // Step 5. Calculate the iteration domain for each block var - std::vector iter_doms = + std::vector iter_doms = CalculateBlockVarDomain(/*iter_vars=*/block->iter_vars, /*provided_regions=*/std::move(provided_regions), /*required_regions=*/std::move(required_regions), /*analyzer=*/analyzer); // Step 6. Create the new scope according to the iteration domain - reconstructor.MakeNewLoop(/*insert_position=*/insert_position, /*iter_doms=*/std::move(iter_doms), - /*preserve_unit_loops=*/preserve_unit_loops); + reconstructor.MakeNewLoop(/*insert_position=*/insert_position, + /*iter_doms=*/std::move(iter_doms), + /*analyzer=*/analyzer, + /*preserve_unit_loops=*/preserve_unit_loops, + /*state=*/self); Block new_scope_root = Downcast(reconstructor(scope_root)); - - // Step 7. Do the actual replacement - if (check_only) { - return; - } - self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}}); - // Step 8. Update the cached flags - BlockInfo& block_info = self->block_info[block_sref]; - block_info.affine_binding = IsAffineBinding( - /*realize=*/reconstructor.new_block_realize_, - /*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef(block_sref->parent)), - /*analyzer=*/analyzer); + Optional bound_predicate = reconstructor.predicate; + return [=]() -> void { + // Step 7. Do the actual replacement + self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}}); + // Step 8. Update the cached flags + BlockInfo& block_info = self->block_info[block_sref]; + block_info.affine_binding = IsAffineBinding( + /*realize=*/reconstructor.new_block_realize_, + /*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef(block_sref->parent)), + /*analyzer=*/analyzer); + // Step 9. Add bound predicate annotation for the block to be moved if needed + if (bound_predicate.defined()) { + Annotate(self, block_sref, attr::require_block_var_bound_predicate, bound_predicate.value()); + } + }; } void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, bool preserve_unit_loops) { arith::Analyzer analyzer; ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, - &analyzer); + &analyzer)(); } void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, bool preserve_unit_loops) { arith::Analyzer analyzer; ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, - &analyzer); + &analyzer)(); } bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref, @@ -539,7 +649,7 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S arith::Analyzer analyzer; try { ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, - &analyzer, true); + &analyzer); } catch (const tvm::runtime::Error& e) { return false; } @@ -551,7 +661,7 @@ bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref, arith::Analyzer analyzer; try { ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, - &analyzer, true); + &analyzer); } catch (const tvm::runtime::Error& e) { return false; } diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index fe2c679142b7..0c86f7b698aa 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -542,8 +542,7 @@ class ReverseComputeInliner : public BaseInliner { PrimExpr producer_rhs_{nullptr}; }; -void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref, - bool check_only = false) { +std::function ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref) { const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(_producer_block, producer_block_sref); Block producer_block = GetRef(_producer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); @@ -568,27 +567,24 @@ void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref, throw OpaqueAccessError(self->mod, scope_root_sref); } // Step 6. Do the real mutation on the AST and the sref tree in the schedule state - if (check_only) { - return; - } - self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); + return [=]() -> void { self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); }; } void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { - ComputeInlineImpl(self, producer_block_sref); + ComputeInlineImpl(self, producer_block_sref)(); } bool CanComputeInline(const ScheduleState& self, const StmtSRef& producer_block_sref) { try { - ComputeInlineImpl(self, producer_block_sref, true); + ComputeInlineImpl(self, producer_block_sref); } catch (const tvm::runtime::Error& e) { return false; } return true; } -void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block_sref, - bool check_only = false) { +std::function ReverseComputeInlineImpl(ScheduleState self, + const StmtSRef& consumer_block_sref) { const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(_consumer_block, consumer_block_sref); Block consumer_block = GetRef(_consumer_block); // Step 1. Get the scope block @@ -615,15 +611,12 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block throw OpaqueAccessError(self->mod, scope_root_sref); } // Step 7. Do the real mutation on the AST and the sref tree in the schedule state - if (check_only) { - return; - } - self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); + return [=]() -> void { self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse); }; } bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref) { try { - ReverseComputeInlineImpl(self, block_sref, true); + ReverseComputeInlineImpl(self, block_sref); } catch (const tvm::runtime::Error& e) { return false; } @@ -631,7 +624,7 @@ bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sr } void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) { - ReverseComputeInlineImpl(self, consumer_block_sref); + ReverseComputeInlineImpl(self, consumer_block_sref)(); } /******** InstructionKind Registration ********/ diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index 55869e12b6b2..acab85460a71 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -83,7 +83,7 @@ void CheckLoopParallelizableInBlock(const ScheduleState& self, ForKind for_kind, const Block& block = block_realize->block; // Cond 1. The block is required to have affine bindings. - CheckAffineBinding(self, block); + /* CheckAffineBinding(self, block); */ // Cond 2. For each block iter whose binding contains `loop_var`, only two cases are allowed. ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size()); diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc new file mode 100644 index 000000000000..37b1dfc5d026 --- /dev/null +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +class TransformLayoutRewriter : private StmtExprMutator { + public: + /*! + * \brief Rewrite the access to the buffer after the transformation + * \param scope_stmt The parent statement that contains all accesses to the target buffer + * \param old_buffer The target buffer before transformation + * \param new_buffer The new buffer after transformation + * \param index_map The transformation applied to the buffer + * \return The new AST rooting at the original parent scope and the map from the old block to the + * new block + */ + static std::pair> Rewrite(const Stmt& scope_stmt, + const Buffer& old_buffer, + const Buffer& new_buffer, + const IndexMap& index_map) { + TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map); + Stmt result = rewriter(scope_stmt); + return {result, rewriter.block_sref_reuse_}; + } + + private: + TransformLayoutRewriter(const Buffer& old_buffer, const Buffer& new_buffer, + const IndexMap& index_map) + : old_buffer_(old_buffer), + new_buffer_(new_buffer), + index_map_(index_map), + buffer_data_to_buffer_{{new_buffer->data, new_buffer}} {} + + void RewriteBufferAccess(Buffer* buffer, Array* indices) { + *buffer = new_buffer_; + *indices = index_map_->Apply(*indices); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad buffer_load = Downcast(StmtExprMutator::VisitExpr_(op)); + if (buffer_load->buffer.same_as(old_buffer_)) { + auto* n = buffer_load.CopyOnWrite(); + RewriteBufferAccess(&n->buffer, &n->indices); + } + return std::move(buffer_load); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore buffer_store = Downcast(StmtExprMutator::VisitStmt_(op)); + if (buffer_store->buffer.same_as(old_buffer_)) { + auto* n = buffer_store.CopyOnWrite(); + RewriteBufferAccess(&n->buffer, &n->indices); + } + return std::move(buffer_store); + } + + void RewriteAccessRegion(Array* old_access_regions, + const Array& infered_access_regions) { + auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) { + if (buffer_region->buffer.same_as(old_buffer_)) { + ICHECK(infered_access_regions.size() == 1); + return infered_access_regions[0]; + } + return buffer_region; + }; + (*old_access_regions).MutateByApply(fmutate); + } + + Stmt VisitStmt_(const BlockNode* op) final { + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + auto infered_access_regions = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); + auto* n = block.CopyOnWrite(); + RewriteAccessRegion(&n->reads, infered_access_regions[0]); + RewriteAccessRegion(&n->writes, infered_access_regions[1]); + block_sref_reuse_.Set(GetRef(op), block); + return std::move(block); + } + + const Buffer& old_buffer_; + const Buffer& new_buffer_; + const IndexMap& index_map_; + Map buffer_data_to_buffer_; + Map block_sref_reuse_; +}; + +class BufferIsSubregionError : public ScheduleError { + public: + explicit BufferIsSubregionError(IRModule mod, Buffer buffer) : mod_(mod), buffer_(buffer) {} + + String FastErrorString() const final { + return "ScheduleError: The input buffer is defined in `match_buffer` of a block, it is expected" + " to be a function parameter or allocated by a block"; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "ScheduleError: The input buffer " << buffer_->name << " is defined in `match_buffer` of " + << "a block, it is expected to be a function parameter or allocated by a block."; + return os.str(); + } + + Array LocationsOfInterest() const final { return {}; } + IRModule mod() const final { return mod_; } + + private: + IRModule mod_; + Buffer buffer_; +}; + +void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + bool is_write_index, const IndexMap& index_map) { + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + Buffer old_buffer = GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, + /*is_write=*/!is_write_index); + Optional defining_site_sref; + bool is_alloc; + std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, old_buffer); + if (defining_site_sref.defined() && !is_alloc) { + throw BufferIsSubregionError(self->mod, old_buffer); + } + + StmtSRef scope_sref = defining_site_sref.defined() + ? defining_site_sref.value() + : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false, + /*require_compact_dataflow*/ false); + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); + + // Step 1: Infer the shape of the new buffer + ObjectPtr new_buffer_node = make_object(*(old_buffer.get())); + new_buffer_node->shape = index_map->MapShape(old_buffer->shape); + Buffer new_buffer{new_buffer_node}; + + // Step 2: Rewrite access indices and regions of the buffer + Stmt new_stmt; + Map block_sref_reuse; + std::tie(new_stmt, block_sref_reuse) = TransformLayoutRewriter::Rewrite( + GetRef(scope_block), old_buffer, new_buffer, index_map); + Block new_scope_block = Downcast(new_stmt); + + // Step 3: Rewrite alloc_buffer of the block or buffer_map of the PrimFunc. + if (defining_site_sref.defined()) { + auto* n = new_scope_block.CopyOnWrite(); + n->alloc_buffers.MutateByApply([&old_buffer, &new_buffer](const Buffer& buffer) { + if (buffer.same_as(old_buffer)) { + return new_buffer; + } + return buffer; + }); + block_sref_reuse.Set(GetRef(scope_block), new_scope_block); + } else { + GlobalVar g_var; + GetRootPrimFunc(self->mod, scope_block, &g_var); + IRModuleNode* new_mod = self->mod.CopyOnWrite(); + MapNode* new_map = new_mod->functions.CopyOnWrite(); + PrimFunc ref_new_func = Downcast(std::move(new_map->at(g_var))); + PrimFuncNode* new_func = ref_new_func.CopyOnWrite(); + MapNode* new_buffer_map = new_func->buffer_map.CopyOnWrite(); + for (auto it = new_buffer_map->begin(); it != new_buffer_map->end(); ++it) { + if ((*it).second.same_as(old_buffer)) { + (*it).second = new_buffer; + } + } + new_map->at(g_var) = std::move(ref_new_func); + } + + // Step 4: Replace the scope block with the new block + self->Replace(scope_sref, new_scope_block, block_sref_reuse); +} + +/******** InstructionKind Registration ********/ + +struct TransformLayoutTraits : public UnpackedInstTraits { + static constexpr const char* kName = "TransformLayout"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 3; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, + Bool is_write_index, IndexMap index_map) { + return sch->TransformLayout(block_rv, buffer_index, is_write_index, index_map); + } + + static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, + Bool is_write_index, IndexMap index_map) { + PythonAPICall py("transform_layout"); + py.Input("block", block_rv); + py.Input("buffer_index", buffer_index); + py.Input("is_write_index", is_write_index); + py.Input("index_map", index_map->ToPythonString()); + return py.Str(); + } + + public: + static ObjectRef AttrsAsJSON(const Array& attrs) { + Array attrs_record; + attrs_record.reserve(kNumAttrs); + attrs_record.push_back(attrs[0]); + attrs_record.push_back(attrs[1]); + attrs_record.push_back(String(::tvm::SaveJSON(attrs[2]))); + return std::move(attrs_record); + } + + static Array AttrsFromJSON(const ObjectRef& attrs_record_) { + Array attrs_record = Downcast>(attrs_record_); + Array attrs; + attrs.push_back(attrs_record[0]); + attrs.push_back(attrs_record[1]); + attrs.push_back(::tvm::LoadJSON(Downcast(attrs_record[2]))); + return attrs; + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(TransformLayoutTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 7b9ac488b8b9..fa2a4469b8c9 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -413,7 +413,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, for (int i = 0; i < n; i++) { const PrimExpr& factor = factors[i]; Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)); - substitute_value = substitute_value * factor + var; + if (!is_one(factor)) substitute_value = substitute_value * factor + var; analyzer.Bind(var, Range::FromMinExtent(0, factor)); new_loop_vars.emplace_back(std::move(var)); } @@ -505,11 +505,14 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix); Array substitute_value; substitute_value.resize(loops.size()); - PrimExpr tot = fused_var; - for (int i = static_cast(loops.size()) - 1; i >= 0; i--) { - substitute_value.Set(i, floormod(tot, loops[i]->extent)); - tot = floordiv(tot, loops[i]->extent); - } + PrimExpr lower = 1; + for (int i = static_cast(loops.size()) - 1; i > 0; i--) { + substitute_value.Set(i, is_one(loops[i]->extent) + ? 0 + : floordiv(floormod(fused_var, lower * loops[i]->extent), lower)); + lower = lower * loops[i]->extent; + } + substitute_value.Set(0, is_one(loops[0]->extent) ? 0 : floordiv(fused_var, lower)); Stmt new_stmt = loops.back()->body; Map opaque_block_reuse; auto f_substitute = [&](const Var& v) -> Optional { @@ -534,6 +537,7 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse); return self->stmt2ref.at(new_stmt.get()); } + /*! * \brief Collect an array of loop srefs into a set * \param self The schedule state diff --git a/src/tir/schedule/primitive/read_write_at.cc b/src/tir/schedule/primitive/read_write_at.cc new file mode 100644 index 000000000000..5d99abaef658 --- /dev/null +++ b/src/tir/schedule/primitive/read_write_at.cc @@ -0,0 +1,422 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { + +using support::NDIntSet; + +bool HasBuffer(const Array& buffer_regions, const Buffer& buffer) { + for (const BufferRegion& buffer_region : buffer_regions) { + if (buffer_region->buffer.same_as(buffer)) { + return true; + } + } + return false; +} + +void RelaxBufferRegions(const Array& buffer_regions, + const Buffer& buffer, // + const Map& var_dom, // + const Map& bindings, // + std::vector* relaxed_regions) { + for (const BufferRegion& buffer_region : buffer_regions) { + if (buffer_region->buffer.same_as(buffer)) { + Array relaxed_region = + arith::EvalSet(Substitute(buffer_region->region, bindings), var_dom); + relaxed_regions->push_back({relaxed_region.begin(), relaxed_region.end()}); + } + } +} + +class ScopeReplacer : public StmtMutator { + public: + static Block Replace(const BlockNode* scope_block, const Buffer& dst, const ForNode* old_loop, + const ForNode* new_loop) { + ObjectPtr new_scope_block = make_object(*scope_block); + new_scope_block->body = ScopeReplacer(old_loop, new_loop)(std::move(new_scope_block->body)); + new_scope_block->alloc_buffers.push_back(dst); + return Block(new_scope_block); + } + + private: + explicit ScopeReplacer(const ForNode* old_loop, const ForNode* new_loop) + : old_loop_(old_loop), new_loop_(new_loop), found_(false) {} + + Stmt VisitStmt(const Stmt& stmt) final { return found_ ? stmt : StmtMutator::VisitStmt(stmt); } + Stmt VisitStmt_(const BlockNode* block) final { return GetRef(block); } + Stmt VisitStmt_(const ForNode* loop) final { + if (loop == old_loop_) { + found_ = true; + return GetRef(new_loop_); + } + return StmtMutator::VisitStmt_(loop); + } + + const ForNode* old_loop_; + const ForNode* new_loop_; + bool found_; +}; + +class ReadWriteAtBufferReplacer : public StmtExprMutator { + public: + explicit ReadWriteAtBufferReplacer(const Buffer& src, const Buffer& dst, + Map* block_sref_reuse) + : src_(src), dst_(dst), block_sref_reuse_(block_sref_reuse) {} + + private: + Stmt VisitStmt_(const BufferStoreNode* _store) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_store)); + if (store->buffer.same_as(src_)) { + ObjectPtr new_store = make_object(*store.get()); + new_store->buffer = dst_; + return BufferStore(new_store); + } + return store; + } + + PrimExpr VisitExpr_(const BufferLoadNode* _load) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); + if (load->buffer.same_as(src_)) { + ObjectPtr new_load = make_object(*load.get()); + new_load->buffer = dst_; + return BufferLoad(new_load); + } + return load; + } + + Stmt VisitStmt_(const BlockNode* _block) final { + Block old_block = GetRef(_block); + Block block = Downcast(StmtExprMutator::VisitStmt_(_block)); + ObjectPtr new_block = make_object(*block.get()); + new_block->reads = ReplaceBuffer(new_block->reads, src_, dst_); + new_block->writes = ReplaceBuffer(new_block->writes, src_, dst_); + block_sref_reuse_->Set(old_block, Block(new_block)); + return Block(new_block); + } + + const Buffer& src_; + const Buffer& dst_; + Map* block_sref_reuse_; +}; + +struct ReadWriteAtImpl { + template + static StmtSRef Main(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int buffer_index, const String& storage_scope, + Map annotations) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Buffer src = + GetNthAccessBuffer(self, GetRef(block), buffer_index, /*is_write=*/!is_read); + Buffer dst = WithScope(src, storage_scope); + ReadWriteAtImpl impl(self, loop_sref, src, dst, annotations); + std::pair new_loop_block = + impl.MakeLoopAndBlock(src->name + "_" + storage_scope); + StmtSRef result_block_sref = + impl.ReplaceScopeBlock(new_loop_block.first.get(), new_loop_block.second->block.get()); + impl.UpdateBlockInfo(result_block_sref); + return result_block_sref; + } + + private: + static Map GetLoopDomain(const StmtSRefNode* loop_sref) { + Map result; + for (const ForNode* loop; (loop = loop_sref->StmtAs()) != nullptr; + loop_sref = loop_sref->parent) { + result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + return result; + } + + StmtSRef ReplaceScopeBlock(const ForNode* new_loop, const BlockNode* new_block) { + StmtSRef scope_root_sref = GetScopeRoot(self_, loop_sref_, + /*require_stage_pipeline=*/true, + /*require_subtree_compact_dataflow=*/false); + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_root_sref); + Block new_scope_block = ScopeReplacer::Replace(scope_block, dst_, loop_, new_loop); + block_sref_reuse_.Set(GetRef(scope_block), new_scope_block); + self_->Replace(scope_root_sref, new_scope_block, block_sref_reuse_); + return self_->stmt2ref.at(new_block); + } + + void UpdateBlockInfo(const StmtSRef& new_block_sref) { + BlockInfo& block_info = self_->block_info[new_block_sref]; + block_info.affine_binding = true; + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + } + + template + std::pair MakeLoopAndBlock(const String& new_block_name_hint) { + Array subtrees = AsArray(loop_->body); + int n_subtrees = subtrees.size(); + runtime::StorageScope scope = runtime::StorageScope::Create(dst_.scope()); + std::vector relaxed_regions; + std::vector r_pos; + std::vector w_pos; + relaxed_regions.reserve(n_subtrees); + r_pos.reserve(n_subtrees); + w_pos.reserve(n_subtrees); + // Step 1. Iterate over all subtrees + for (int i = 0; i < n_subtrees; ++i) { + bool r_visited = false; + bool w_visited = false; + auto f_visit = [this, &relaxed_regions, &r_visited, &w_visited, + &scope](const ObjectRef& obj) -> bool { + const BlockRealizeNode* realize = obj.as(); + if (realize == nullptr) { + return true; + } + const BlockNode* block = realize->block.get(); + bool has_r = HasBuffer(block->reads, src_); + bool has_w = HasBuffer(block->writes, src_); + r_visited = r_visited || has_r; + w_visited = w_visited || has_w; + if (is_read ? has_r : has_w) { + RelaxBufferRegions( + /*buffer_regions=*/is_read ? block->reads : block->writes, + /*buffer=*/src_, + /*var_dom=*/ + AsIntSet(LoopDomainOfSRefTreePath( + /*low_inclusive=*/GetRef(self_->stmt2ref.at(block)->parent), + /*high_exclusive=*/loop_sref_, + /*extra_relax_scope=*/scope)), + /*bindings=*/GetBindings(GetRef(realize)), + /*relaxed_regions=*/&relaxed_regions); + } + return false; + }; + PreOrderVisit(subtrees[i], f_visit); + if (r_visited) { + r_pos.push_back(i); + } + if (w_visited) { + w_pos.push_back(i); + } + } + // Step 2. Calculate `insert_pos` and [st, ed) for buffer replacement + int insert_pos = -1, st = -1, ed = -1; + if (is_read) { + ICHECK(!r_pos.empty()); + // No write after the first read + ICHECK(w_pos.empty() || w_pos.back() < r_pos.front()); + // Can be inserted at [0, r_pos.front()], i.e. before the first read + insert_pos = r_pos.front(); + // Buffer reads in [insert_pos, +oo) is rewritten + st = insert_pos; + ed = n_subtrees; + } else { + ICHECK(!w_pos.empty()); + // No read after the last write + ICHECK(r_pos.empty() || r_pos.back() <= w_pos.back()); + // Can be inserted into (w_pos.back(), +oo), i.e. after the last write + insert_pos = w_pos.back() + 1; + st = 0; + ed = insert_pos; + } + // Step 3. Calculate `domain`, the domain of buffer access + NDIntSet relaxed = support::NDIntSetUnion(relaxed_regions); + int ndim = relaxed.size(); + Array domain; + domain.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + const arith::IntSet& int_set = relaxed[i]; + PrimExpr min = analyzer_->Simplify(int_set.min()); + PrimExpr extent = analyzer_->Simplify(int_set.max() + 1 - min); + domain.push_back(Range::FromMinExtent(min, extent)); + } + // Step 4. Insert the auto copy block and replace buffers + ReadWriteAtBufferReplacer replacer(src_, dst_, &block_sref_reuse_); + for (int i = st; i < ed; ++i) { + Stmt stmt = subtrees[i]; + subtrees.Set(i, Stmt(nullptr)); + subtrees.Set(i, replacer(std::move(stmt))); + } + BlockRealize realize = + is_read + ? MakeBlock(src_, dst_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain) + : MakeBlock(dst_, src_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain); + subtrees.insert(subtrees.begin() + insert_pos, realize); + ObjectPtr new_loop = make_object(*loop_); + new_loop->body = SeqStmt(std::move(subtrees)); + return {For(new_loop), realize}; + } + + BlockRealize MakeBlock(const Buffer& copy_from, const Buffer& copy_to, const String& name_hint, + const Map& loop_domain, Array domain) const { + int n = domain.size(); + std::vector loop_vars; + loop_vars.reserve(n); + for (int i = 0; i < n; ++i) { + loop_vars.push_back(Var("ax" + std::to_string(i))); + } + Map bindings; + Array iter_vars; + Array iter_values; + Array indices; + iter_vars.reserve(n); + iter_values.reserve(n); + indices.reserve(n); + for (int i = 0; i < n; ++i) { + auto f_substitute = [&loop_domain, &bindings, &iter_vars, + &iter_values](const Var& var) -> Optional { + auto it = bindings.find(var); + if (it != bindings.end()) { + return (*it).second; + } + Range range = loop_domain.at(var); + ObjectPtr v = make_object(*var.get()); + v->name_hint = "v" + std::to_string(iter_vars.size()); + bindings.Set(var, Var(v)); + iter_values.push_back(var); + iter_vars.push_back(IterVar(range, Var(v), IterVarType::kDataPar)); + return Var(v); + }; + ObjectPtr dom = make_object(*domain[i].get()); + dom->min = Substitute(std::move(dom->min), f_substitute); + dom->extent = Substitute(std::move(dom->extent), f_substitute); + domain.Set(i, Range(dom)); + } + for (int i = 0; i < n; ++i) { + indices.push_back(domain[i]->min + loop_vars[i]); + } + Stmt stmt = BufferStore(copy_to, /*value=*/BufferLoad(copy_from, indices), /*indices=*/indices); + for (int i = n - 1; i >= 0; --i) { + stmt = For(loop_vars[i], Integer(0), domain[i]->extent, ForKind::kSerial, stmt); + } + return BlockRealize( + /*values=*/iter_values, + /*predicate=*/const_true(), + Block(/*iter_vars=*/iter_vars, + /*reads=*/{BufferRegion(copy_from, domain)}, + /*writes=*/{BufferRegion(copy_to, domain)}, + /*name_hint=*/name_hint, // + /*body=*/std::move(stmt), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/annotations_)); + } + + explicit ReadWriteAtImpl(ScheduleState self, const StmtSRef& loop_sref, const Buffer& src, + const Buffer& dst, Map annotations) + : self_(self), + loop_sref_(loop_sref), + loop_(nullptr), + src_(src), + dst_(dst), + annotations_(annotations), + block_sref_reuse_(), + analyzer_(std::make_unique()) { + loop_ = TVM_SREF_TO_FOR(loop_, loop_sref); + } + + ScheduleState self_; + const StmtSRef& loop_sref_; + const ForNode* loop_; + const Buffer& src_; + const Buffer& dst_; + Map annotations_; + Map block_sref_reuse_; + std::unique_ptr analyzer_; +}; + +StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int read_buffer_index, const String& storage_scope) { + return ReadWriteAtImpl::Main(self, loop_sref, block_sref, read_buffer_index, storage_scope, + {{"auto_copy", Integer(1)}}); +} + +StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int write_buffer_index, const String& storage_scope) { + return ReadWriteAtImpl::Main(self, loop_sref, block_sref, write_buffer_index, + storage_scope, {{"auto_copy", Integer(1)}}); +} + +/******** Instruction Registration ********/ + +struct ReadAtTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReadAt"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int buffer_index, const String& storage_scope); + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block, + Integer read_buffer_index, String storage_scope) { + return sch->ReadAt(loop, block, read_buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String loop, String block, + Integer read_buffer_index, String storage_scope) { + PythonAPICall py("read_at"); + py.Input("loop", loop); + py.Input("block", block); + py.Input("read_buffer_index", read_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct WriteAtTraits : public UnpackedInstTraits { + static constexpr const char* kName = "WriteAt"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block, + Integer write_buffer_index, String storage_scope) { + return sch->WriteAt(loop, block, write_buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String loop, String block, + Integer write_buffer_index, String storage_scope) { + PythonAPICall py("write_at"); + py.Input("loop", loop); + py.Input("block", block); + py.Input("write_buffer_index", write_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(ReadAtTraits); +TVM_REGISTER_INST_KIND_TRAITS(WriteAtTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 096e616ec3ff..6c174560e954 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -281,7 +281,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, body = For(/*loop_var=*/new_loop_var, /*min=*/old_loop->min, /*extent=*/old_loop->extent, - /*kind=*/ForKind::kSerial, + /*kind=*/old_loop->kind, /*body=*/body); } body = Substitute(body, loop_var_map); @@ -485,13 +485,11 @@ class LoopPropertyError : public ScheduleError { if (loop.get() == rf_loop) { throw LoopPropertyError(self->mod, loop, kDataParIterTouchRFactorLoop); } - continue; } else if (reduction_touched) { if (!meet_reduction_loop) { CheckGetSingleChildBlockRealizeOnSRefTree(self, self->stmt2ref.at(loop.get())); meet_reduction_loop = true; } - continue; } else if (meet_reduction_loop && !is_one(loop->extent)) { throw LoopPropertyError(self->mod, loop, kUnboundLoopUnderReductionLoop); } @@ -560,10 +558,13 @@ class BaseBlockCreator { } void CreateBlock() { - CreateAdditionalIter(); for (int i = 0; i < n_block_iters_; ++i) { CreateNormalIters(i); } + if (!additional_iter_.defined()) { + ICHECK(arith::Analyzer().CanProveEqual(rf_loop_->extent, Integer(1))); + CreateAdditionalIter(); + } CreateReductionUpdate(); CreateReadWriteRegions(); @@ -589,8 +590,8 @@ class BaseBlockCreator { } private: - virtual void CreateAdditionalIter() = 0; virtual void CreateNormalIters(int idx) = 0; + virtual void CreateAdditionalIter() = 0; virtual void CreateReductionUpdate() = 0; virtual void CreateReadWriteRegions() = 0; @@ -601,6 +602,8 @@ class BaseBlockCreator { BlockRealize new_block_realize_; /*! \brief The indices used to access the intermediate rfactor buffer */ Array rf_buf_access_indices_; + /*! \brief The additional block iter of the new created block for the rfactor loop. */ + IterVar additional_iter_; protected: /*! \brief The old block-realize */ @@ -672,15 +675,6 @@ class RFactorBlockCreator : public BaseBlockCreator { combiner_rhs_(std::move(combiner_rhs)) {} private: - void CreateAdditionalIter() final { - // Create a new data parallel block iter for the rfactor loop. - additional_iter_ = - IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, IterVarType::kDataPar); - loop_var2block_binding_[rf_loop_->loop_var.get()] = additional_iter_->var; - iter_vars_.push_back(additional_iter_); - iter_values_.push_back(rf_loop_->loop_var); - } - void CreateNormalIters(int idx) final { IterVar old_iter = old_block_realize_->block->iter_vars[idx]; PrimExpr old_binding = old_block_realize_->iter_values[idx]; @@ -706,20 +700,31 @@ class RFactorBlockCreator : public BaseBlockCreator { } const For& loop = it->second; if (loop_var2block_binding_.find(var.get()) == loop_var2block_binding_.end()) { - // We haven't created the new block iter for `var`. So here we create it, append it - // and its binding to `rf_block_iter_vars` and `rf_block_iter_values` respectively. - IterVar new_iter_var = - IterVarFromLoop(loop, "v" + loop->loop_var->name_hint, IterVarType::kCommReduce); + // - We haven't created the new block iter for `var`. So here we create it, append it + // and its binding to `rf_block_iter_vars` and `rf_block_iter_values` respectively. + // - If the loop is the rfactor loop, envoke `CreateAdditionalIter()`. + if (loop.same_as(rf_loop_)) { + CreateAdditionalIter(); + continue; + } + IterVar new_iter_var = IterVarFromLoop(loop, "v" + loop->loop_var->name_hint, kCommReduce); loop_var2block_binding_[var.get()] = new_iter_var->var; iter_vars_.push_back(new_iter_var); iter_values_.push_back(var); } } // Substitute the original binding with new block iters. Store the result expression - // in `rf_var_map` for future substitution. + // in `var_map_` for future substitution. var_map_.Set(old_iter->var, Substitute(old_binding, loop_var2block_binding_)); } + void CreateAdditionalIter() final { + additional_iter_ = IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, kDataPar); + iter_vars_.insert(iter_vars_.end(), additional_iter_); + iter_values_.insert(iter_values_.end(), rf_loop_->loop_var); + loop_var2block_binding_[rf_loop_->loop_var.get()] = additional_iter_; + } + void CreateReductionUpdate() final { rf_buf_access_indices_ = old_reduction_update_->indices; rf_buf_access_indices_.insert(rf_buf_access_indices_.begin() + factor_axis_, @@ -754,10 +759,6 @@ class RFactorBlockCreator : public BaseBlockCreator { return new_regions; } - public: - /*! \brief The generated additional block iter in rfactor block for the rfactor loop */ - IterVar additional_iter_; - private: /*! * \brief A mapping which maps a loop var to its corresponding For loop for all the reduction @@ -797,15 +798,6 @@ class WriteBackBlockCreator : public BaseBlockCreator { } private: - void CreateAdditionalIter() final { - // Create a new reduction block iter for the rfactor loop. - IterVar wb_new_block_iter = - IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, kCommReduce); - iter_vars_.push_back(wb_new_block_iter); - iter_values_.push_back(rf_loop_->loop_var); - var_map_.Set(rf_additional_iter_->var, wb_new_block_iter->var); - } - void CreateNormalIters(int idx) final { IterVar old_block_iter = old_block_realize_->block->iter_vars[idx]; if (old_block_iter->iter_type == IterVarType::kDataPar) { @@ -813,9 +805,26 @@ class WriteBackBlockCreator : public BaseBlockCreator { kDataPar); iter_values_.push_back(old_block_realize_->iter_values[idx]); var_map_.Set(old_block_iter->var, iter_vars_.back()); + return; + } + + ICHECK(old_block_iter->iter_type == IterVarType::kCommReduce); + // If the old block iter touches the reduction loop and we have not created a new reduction + // block iter for the rfactor loop, create one now. + if (!additional_iter_.defined() && + UsesVar(old_block_realize_->iter_values[idx], + [v = rf_loop_->loop_var.get()](const VarNode* var) { return var == v; })) { + CreateAdditionalIter(); } } + void CreateAdditionalIter() final { + additional_iter_ = IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, kCommReduce); + iter_vars_.insert(iter_vars_.end(), additional_iter_); + iter_values_.insert(iter_values_.end(), rf_loop_->loop_var); + var_map_.Set(rf_additional_iter_->var, additional_iter_->var); + } + void CreateReductionUpdate() final { wb_lhs_ = Downcast(Substitute(combiner_lhs_, var_map_)); wb_rhs_ = diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 171838572dbb..cbb4e66918e9 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -86,6 +86,7 @@ struct PrimeTable { pow_tab.emplace_back(std::move(tab)); } } + /*! * \brief Factorize a number n, and return in a cryptic format * \param n The number to be factorized @@ -187,6 +188,28 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st return candidates[i]; } +std::function MakeMultinomialSampler( + support::LinearCongruentialEngine::TRandState* rand_state, const std::vector& weights) { + ICHECK(!weights.empty()); + std::vector sums; + sums.reserve(weights.size()); + double sum = 0.0; + for (double w : weights) { + sums.push_back(sum += w); + } + return [rng = support::LinearCongruentialEngine(rand_state).ForkSeed(), + dist = std::uniform_real_distribution(0.0, sum), + sums = std::move(sums)]() mutable -> int32_t { + support::LinearCongruentialEngine rand_(&rng); + double p = dist(rand_); + int32_t idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); + int32_t n = sums.size(); + CHECK_LE(0, idx); + CHECK_LE(idx, n); + return (idx == n) ? (n - 1) : idx; + }; +} + std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandState* rand_state, int32_t extent, int32_t n_splits) { CHECK_GE(extent, 1) << "ValueError: Cannot tile a loop with 0 or negative extent"; @@ -277,32 +300,22 @@ std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandS return SamplePerfectTile(rand_state, extent, n_splits); } CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits"; - std::vector innermost_candidates; - innermost_candidates.reserve(max_innermost_factor); - for (int32_t i = 1; i <= max_innermost_factor; ++i) { - if (extent % i == 0) { - innermost_candidates.push_back(i); + while (true) { + std::vector result = SamplePerfectTile(rand_state, extent, n_splits); + if (result.back() <= max_innermost_factor) { + return result; } } - // N.B. Theoretically sampling evenly breaks the uniform sampling of the global sampling space. - // We should do multiple factorization to weight the choices. However, it would lead to slower - // sampling speed. On the other hand, considering potential tricks we might do on the innermost - // loop, in which sampling uniformly does not help, let's leave it as it is for now, and maybe add - // more heuristics in the future - int32_t innermost = innermost_candidates[SampleInt(rand_state, 0, innermost_candidates.size())]; - std::vector result = SamplePerfectTile(rand_state, extent / innermost, n_splits - 1); - result.push_back(innermost); - return result; } std::vector SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // - const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, + const StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, Optional>* decision) { const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); - int64_t extent = GetLoopIntExtent(loop); + const int64_t* extent = GetLoopIntExtent(loop); std::vector result; - if (extent == -1) { + if (extent == nullptr) { // Case 1. Handle loops with non-constant length result = std::vector(n_splits, 1); result[0] = -1; @@ -311,7 +324,7 @@ std::vector SamplePerfectTile( result = support::AsVector(decision->value()); int n = result.size(); ICHECK_GE(n, 2); - int64_t len = extent; + int64_t len = *extent; for (int i = n - 1; i > 0; --i) { int64_t& l = result[i]; // A previous decision could become invalid because of the change of outer tiles @@ -325,13 +338,49 @@ std::vector SamplePerfectTile( result[0] = len; } else { // Case 3. Use fresh new sampling result - result = SamplePerfectTile(rand_state, extent, n_splits, max_innermost_factor); + result = SamplePerfectTile(rand_state, *extent, n_splits, max_innermost_factor); ICHECK_LE(result.back(), max_innermost_factor); } *decision = support::AsArray(result); return result; } +tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, + support::LinearCongruentialEngine::TRandState* rand_state, + const StmtSRef& block_sref, Optional* decision) { + // Step 1. Collect all possible compute-at locations. + Array location_srefs; + std::vector location_indices; + std::tie(location_srefs, location_indices) = CollectComputeLocation(self, block_sref); + ICHECK_EQ(location_srefs.size(), location_indices.size()); + + // Step 2. If there was a previous decision, keep the decision unchanged if it exists in the + // location candidates. Otherwise, pick the location before the previous decision. + // Step 3. If there was not a previous decision, sample a decision from the collected locations. + if (decision->defined()) { + int64_t old_decision = Downcast(*decision)->value; + auto it = std::lower_bound(location_indices.begin(), location_indices.end(), old_decision); + int idx = it - location_indices.begin(); + + if (it != location_indices.end() && *it == old_decision) { + *decision = Integer(old_decision); + return location_srefs[idx]; + } else if (it != location_indices.begin()) { + *decision = Integer(*--it); + return location_srefs[idx - 1]; + } else { + *decision = Integer(-1); + return StmtSRef::RootMark(); + } + } else { + int sampled_idx = SampleInt(rand_state, 0, location_indices.size()); + *decision = Integer(location_indices[sampled_idx]); + return location_srefs[sampled_idx]; + } + ICHECK(false) << "Cannot reach here"; + throw; +} + /******** InstructionKind Registration ********/ struct SampleCategoricalTraits : public UnpackedInstTraits { @@ -396,8 +445,37 @@ struct SamplePerfectTileTraits : public UnpackedInstTraits { + static constexpr const char* kName = "SampleComputeLocation"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 1; + + static LoopRV UnpackedApplyToSchedule(Schedule sch, // + BlockRV block_rv, // + Optional decision) { + return sch->SampleComputeLocation(block_rv, decision); + } + + static String UnpackedAsPython(Array outputs, // + String block_rv, // + Optional decision) { + PythonAPICall py("sample_compute_location"); + py.Input("block", block_rv); + py.Decision(decision); + py.SingleOutput(outputs); + return py.Str(); + } + + friend struct UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(SampleCategoricalTraits); TVM_REGISTER_INST_KIND_TRAITS(SamplePerfectTileTraits); +TVM_REGISTER_INST_KIND_TRAITS(SampleComputeLocationTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 75939f00b8f4..a925e66d6ecb 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -125,6 +125,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") .set_body_method(&ScheduleNode::SampleCategorical); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePerfectTile") .set_body_method(&ScheduleNode::SamplePerfectTile); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleComputeLocation") + .set_body_method(&ScheduleNode::SampleComputeLocation); /******** (FFI) Get blocks & loops ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") .set_body_method(&ScheduleNode::GetBlock); @@ -163,6 +165,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead") .set_body_method(&ScheduleNode::CacheRead); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") .set_body_method(&ScheduleNode::CacheWrite); +/******** (FFI) Data movement ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReadAt").set_body_method(&ScheduleNode::ReadAt); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWriteAt") + .set_body_method(&ScheduleNode::WriteAt); /******** (FFI) Compute location ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt") .set_body_method(&ScheduleNode::ComputeAt); @@ -183,6 +189,20 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope") .set_body_method(&ScheduleNode::SetScope); /******** (FFI) Blockize & Tensorize ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize") + .set_body_method(&ScheduleNode::Blockize); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize") + .set_body_typed([](Schedule self, LoopRV loop_rv, ObjectRef intrin) { + if (const auto* str = intrin.as()) { + return self->Tensorize(loop_rv, GetRef(str)); + } + if (const auto* p_intrin = intrin.as()) { + return self->Tensorize(loop_rv, GetRef(p_intrin)); + } + LOG(FATAL) << "TypeError: Cannot handle type: " << intrin->GetTypeKey(); + throw; + }); + /******** (FFI) Annotation ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate") .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key, @@ -210,6 +230,9 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") throw; }); +/******** (FFI) Layout transformation ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout") + .set_body_method(&ScheduleNode::TransformLayout); /******** (FFI) Misc ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc") .set_body_method(&ScheduleNode::EnterPostproc); diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index 04b7dd5ea2af..0b67587c6163 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -338,6 +338,10 @@ class BlockInfoCollector : private StmtVisitor { /*dom_low_inclusive=*/parent_sref, /*dom_high_exclusive=*/lca, /*analyzer=*/&analyzer_); + for (size_t i = 0; i < consumed_region.size(); ++i) { + const arith::IntSet consumed_interset = arith::Intersect( + {consumed_region[i], arith::IntSet::FromMinExtent(0, buffer->shape[i])}); + } if (!ProducerCoversConsumer(buffer->shape, produced_region, consumed_region, &analyzer_)) { region_cover = false; @@ -897,7 +901,7 @@ class ChildReplacer : private StmtMutator { int seq_index_; }; -void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt, +void ScheduleStateNode::Replace(const StmtSRef& _src_sref, const Stmt& tgt_stmt, const Map& _block_sref_reuse) { if (this->debug_mask != 0) { const StmtNode* src_stmt = _src_sref->stmt; diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index d8c18f0de0d6..3ef22b358972 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -34,18 +34,13 @@ Trace::Trace(Array insts, Map decisions) { /**************** Utilities ****************/ -bool IsPostproc(const InstructionKind& inst_kind) { - static InstructionKind inst_enter_postproc = InstructionKind::Get("EnterPostproc"); - return inst_kind.same_as(inst_enter_postproc); -} - int GetNumValidInstructions(const Array& insts, bool remove_postproc) { if (!remove_postproc) { return insts.size(); } int n_insts = 0; for (const Instruction& inst : insts) { - if (!IsPostproc(inst->kind)) { + if (!inst->kind->IsPostproc()) { ++n_insts; } else { break; @@ -242,7 +237,7 @@ void TraceNode::ApplyToSchedule( decision_provider) const { std::unordered_map rv_map; for (const Instruction& inst : this->insts) { - if (remove_postproc && IsPostproc(inst->kind)) { + if (remove_postproc && inst->kind->IsPostproc()) { break; } Array inputs = TranslateInputRVs(inst->inputs, rv_map); @@ -266,7 +261,7 @@ ObjectRef TraceNode::AsJSON(bool remove_postproc) const { int i = 0; for (const Instruction& inst : this->insts) { const InstructionKind& kind = inst->kind; - if (remove_postproc && IsPostproc(kind)) { + if (remove_postproc && kind->IsPostproc()) { break; } json_insts.push_back(Array{ @@ -295,7 +290,7 @@ Array TraceNode::AsPython(bool remove_postproc) const { Array py_trace; py_trace.reserve(this->insts.size()); for (const Instruction& inst : this->insts) { - if (remove_postproc && IsPostproc(inst->kind)) { + if (remove_postproc && inst->kind->IsPostproc()) { break; } Array attrs; @@ -440,8 +435,10 @@ Trace TraceNode::Simplified(bool remove_postproc) const { } // Add its inputs as "used" ones for (const ObjectRef& obj : inst->inputs) { - if (obj->IsInstance() || obj->IsInstance() || - obj->IsInstance()) { + if (!obj.defined()) { + continue; + } else if (obj->IsInstance() || obj->IsInstance() || + obj->IsInstance()) { used_rvs.insert(obj.get()); continue; } else if (obj->IsInstance()) { diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 61283668f85d..fc92306354fe 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -29,7 +29,7 @@ Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRand n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); n->trace_ = Trace(); - support::LinearCongruentialEngine(&n->rand_state_).Seed(seed); + n->Seed(seed); return Schedule(std::move(n)); } @@ -73,6 +73,20 @@ Array TracedScheduleNode::SamplePerfectTile(const LoopRV& loop_rv, int n return results; } +LoopRV TracedScheduleNode::SampleComputeLocation(const BlockRV& block_rv, + Optional decision) { + LoopRV result = CreateRV(tir::SampleComputeLocation(this->state_, &this->rand_state_, + this->GetSRef(block_rv), &decision)); + + static const InstructionKind& kind = InstructionKind::Get("SampleComputeLocation"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{result}), + /*decision=*/decision); + return result; +} + /******** Schedule: Get blocks & loops ********/ BlockRV TracedScheduleNode::GetBlock(const String& name, const String& func_name) { @@ -250,6 +264,31 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer return result; } +BlockRV TracedScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int read_buffer_index, const String& storage_scope) { + BlockRV result = + ConcreteScheduleNode::ReadAt(loop_rv, block_rv, read_buffer_index, storage_scope); + + static const InstructionKind& kind = InstructionKind::Get("ReadAt"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv, block_rv}, + /*attrs=*/{Integer(read_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} + +BlockRV TracedScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int write_buffer_index, const String& storage_scope) { + BlockRV result = + ConcreteScheduleNode::WriteAt(loop_rv, block_rv, write_buffer_index, storage_scope); + + static const InstructionKind& kind = InstructionKind::Get("WriteAt"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv, block_rv}, + /*attrs=*/{Integer(write_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} /******** Schedule: Compute location ********/ void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, @@ -259,7 +298,7 @@ void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_r static const InstructionKind& kind = InstructionKind::Get("ComputeAt"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{block_rv, loop_rv}, - /*attrs=*/{Integer(preserve_unit_loops)}, + /*attrs=*/{Bool(preserve_unit_loops)}, /*outputs=*/{})); } @@ -270,7 +309,7 @@ void TracedScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& static const InstructionKind& kind = InstructionKind::Get("ReverseComputeAt"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, /*inputs=*/{block_rv, loop_rv}, - /*attrs=*/{Integer(preserve_unit_loops)}, + /*attrs=*/{Bool(preserve_unit_loops)}, /*outputs=*/{})); } @@ -342,6 +381,27 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, /******** Schedule: Blockize & Tensorize ********/ +BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv) { + BlockRV new_block = ConcreteScheduleNode::Blockize(loop_rv); + static const InstructionKind& kind = InstructionKind::Get("Blockize"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{new_block})); + return new_block; +} + +void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin_name) { + ConcreteScheduleNode::Tensorize(loop_rv, intrin_name); + static const InstructionKind& kind = InstructionKind::Get("Tensorize"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{loop_rv}, + /*attrs=*/{intrin_name}, + /*outputs=*/{})); +} + /******** Schedule: Annotation ********/ void TracedScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key, @@ -382,6 +442,19 @@ void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_k /*outputs=*/{})); } +/******** Schedule: Layout transformation ********/ + +void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index, + bool is_write_index, const IndexMap& index_map) { + ConcreteScheduleNode::TransformLayout(block_rv, buffer_index, is_write_index, index_map); + static const InstructionKind& kind = InstructionKind::Get("TransformLayout"); + trace_->Append( + /*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(buffer_index), Bool(is_write_index), index_map}, + /*outputs=*/{})); +} + /******** Schedule: Misc ********/ void TracedScheduleNode::EnterPostproc() { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 5ce4763f117f..12696567816a 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -51,6 +51,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { Optional decision = NullOpt) final; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) final; + LoopRV SampleComputeLocation(const BlockRV& block_rv, Optional decision = NullOpt) final; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") final; Array GetLoops(const BlockRV& block_rv) final; @@ -72,6 +73,11 @@ class TracedScheduleNode : public ConcreteScheduleNode { const String& storage_scope) final; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) final; + /******** Schedule: Data movement ********/ + BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) final; + BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) final; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final; void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, @@ -86,11 +92,16 @@ class TracedScheduleNode : public ConcreteScheduleNode { int offset) final; void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final; /******** Schedule: Blockize & Tensorize ********/ + BlockRV Blockize(const LoopRV& loop_rv) final; + void Tensorize(const LoopRV& loop_rv, const String& intrin_name) final; /******** Schedule: Annotation ********/ void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const LoopRV& loop_rv, const String& ann_key) override; void Annotate(const BlockRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override; void Unannotate(const BlockRV& loop_rv, const String& ann_key) override; + /******** Schedule: Layout transformation ********/ + void TransformLayout(const BlockRV& block_rv, int buffer_index, bool is_write_index, + const IndexMap& index_map) override; /******** Schedule: Misc ********/ void EnterPostproc() final; }; diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index ffb6b2d52628..fb3829c59a01 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -136,5 +136,98 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ throw OnlyLeafError(self->mod, GetRef(leaf_block), GetRef(scope_block)); } +/******** Utilities for tensorization ********/ + +class IRSubstituteInScope : public StmtExprMutator { + public: + explicit IRSubstituteInScope(std::function fmap) + : fmap_(std::move(fmap)) {} + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = fmap_(op); + if (it.defined()) { + return it; + } else { + return GetRef(op); + } + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + arith::Analyzer analyzer; + auto fmutate = [&](const PrimExpr& e) { return this->VisitExpr(e); }; + Array v = op->iter_values; + v.MutateByApply(fmutate); + PrimExpr pred = this->VisitExpr(op->predicate); + if (v.same_as(op->iter_values) && pred.same_as(op->predicate)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->iter_values = std::move(v); + n->predicate = std::move(analyzer.Simplify(pred)); + return Stmt(n); + } + } + + private: + const std::function fmap_; +}; + +Stmt SubstituteInScope(const Stmt& stmt, + const std::function& value_func) { + return IRSubstituteInScope(value_func)(stmt); +} + +Stmt SubstituteInScope(const Stmt& stmt, + const std::unordered_map& var_map) { + auto vmap = [&](const VarNode* v) -> PrimExpr { + const auto& it = var_map.find(v); + if (it != var_map.end()) { + return it->second; + } else { + return NullValue(); + } + }; + return IRSubstituteInScope(vmap)(stmt); +} + +PrimExpr SubstituteInScope(const PrimExpr& expr, + const std::unordered_map& var_map) { + auto vmap = [&](const VarNode* v) -> PrimExpr { + const auto& it = var_map.find(v); + if (it != var_map.end()) { + return it->second; + } else { + return NullValue(); + } + }; + return IRSubstituteInScope(vmap)(expr); +} + +Stmt SubstituteInScope(const Stmt& stmt, + const std::unordered_map& var_map) { + auto vmap = [&](const VarNode* v) -> PrimExpr { + const auto& it = var_map.find(v); + if (it != var_map.end()) { + return GetRef(it->second); + } else { + return NullValue(); + } + }; + return IRSubstituteInScope(vmap)(stmt); +} + +PrimExpr SubstituteInScope(const PrimExpr& expr, + const std::unordered_map& var_map) { + auto vmap = [&](const VarNode* v) -> PrimExpr { + const auto& it = var_map.find(v); + if (it != var_map.end()) { + return GetRef(it->second); + } else { + return NullValue(); + } + }; + return IRSubstituteInScope(vmap)(expr); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index c66c2ca76693..144ba5f77e99 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -120,6 +120,21 @@ inline Array LoopSRefs2Loops(const Array& loop_srefs) { return loops; } +/*! + * \brief Convert an array of block rvs to an array of block StmtSRefs + * \param sch The schedule used to evaluate the random variables + * \param block_rvs The random variables to be converted + * \return The conversion result srefs + */ +inline Array BlockRVs2StmtSRefs(const Schedule& sch, const Array& block_rvs) { + Array block_srefs; + block_srefs.reserve(block_rvs.size()); + for (const BlockRV& block_rv : block_rvs) { + block_srefs.push_back(sch->GetSRef(block_rv)); + } + return block_srefs; +} + /******** Storage scope ********/ /*! @@ -178,6 +193,18 @@ inline Array AsArray(const Stmt& stmt) { return {stmt}; } +/*! + * \brief Checks of a statement is a SeqStmt that contains multiple statements + * \param stmt The statement to be checked + * \return A boolean indicating the result + */ +inline bool IsSingleStmt(const Stmt& stmt) { + if (const auto* seq_stmt = stmt.as()) { + return seq_stmt->seq.size() == 1; + } + return true; +} + /******** IterVar ********/ /*! @@ -192,6 +219,36 @@ inline IterVar IterVarFromLoop(const For& loop, String name, IterVarType iter_va Var(std::move(name), loop->loop_var.dtype()), iter_var_type); } +/*! + * \brief Get the thread scope bound to the specific loop + * \param loop The loop to be inspected + * \return The thread scope bound to the loop + */ +inline runtime::ThreadScope GetThreadScope(const ForNode* loop) { + if (loop->kind == ForKind::kThreadBinding) { + return runtime::ThreadScope::Create(loop->thread_binding.value()->thread_tag); + } + return runtime::ThreadScope{-1, -1}; +} + +/*! + * \brief Check if the thread scope is blockIdx + * \param thread_scope The thread scope to be checked + * \return True if the thread scope is blockIdx + */ +inline bool IsBlockIdx(const runtime::ThreadScope& thread_scope) { + return thread_scope.rank == 0; // The rank of blockIdx is 0 +} + +/*! + * \brief Check if the thread scope is threadIdx + * \param thread_scope The thread scope to be checked + * \return True if the thread scope is threadIdx + */ +inline bool IsThreadIdx(const runtime::ThreadScope& thread_scope) { + return thread_scope.rank == 1 && thread_scope.dim_index >= 0; +} + /******** Integer set ********/ /*! @@ -210,28 +267,161 @@ inline Map AsIntSet(const Map& var_dom) { return {result.begin(), result.end()}; } -/**************** Loop extents ****************/ +/**************** PrimExpr parsing and extents ****************/ /*! * \brief Get the extents of a loop * \param loop The loop to be queried - * \return The extents of the loop + * \return The extent of the loop, nullptr if the extent is not constant */ -inline int64_t GetLoopIntExtent(const ForNode* loop) { - const auto* int_extent = loop->extent.as(); - return int_extent ? int_extent->value : -1; -} +inline const int64_t* GetLoopIntExtent(const ForNode* loop) { return as_const_int(loop->extent); } /*! * \brief Get the extents of a loop * \param loop_sref The loop to be queried - * \return The extents of the loop + * \return The extent of the loop, nullptr if the extent is not constant */ -inline int64_t GetLoopIntExtent(const StmtSRef& loop_sref) { +inline const int64_t* GetLoopIntExtent(const StmtSRef& loop_sref) { const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); - return GetLoopIntExtent(loop); + return as_const_int(loop->extent); } +/*! + * \brief Check if an expression consists of a single variable, + * or a variable plus/minus an constant integer shift + * \param expr The expression to be checked + * \return result Output, the var if it satisfies the condition; otherwise NullOpt + */ +inline Optional AnalyzeVarWithShift(const PrimExpr& expr, Optional* constant) { + if (const auto* var = expr.as()) { + *constant = NullOpt; + return GetRef(var); + } + arith::PVar var; + arith::PVar shift; + // match: "var + shift" + if ((var + shift).Match(expr) || (shift + var).Match(expr)) { + *constant = shift.Eval(); + return var.Eval(); + } + // match: "var - shift" + if ((var - shift).Match(expr)) { + IntImm result = shift.Eval(); + *constant = IntImm(result->dtype, -result->value); + return var.Eval(); + } + return NullOpt; +} + +/******** Annotation ********/ + +/*! + * \brief Get the annotation on a Block/For + * \tparam TObjectRef The type of the annotation value + * \param sref The sref to the block or the for loop + * \param ann_key The annotation key to be looked up + * \return NullOpt if not found; otherwise the annotation value + */ +template +inline Optional GetAnn(const TStmtNode* stmt, const String& ann_key) { + const Map* annotations = &stmt->annotations; + for (const auto& ann : *annotations) { + if (ann.first == ann_key) { + return Downcast(ann.second); + } + } + return NullOpt; +} + +/*! + * \brief Get the annotation on a Block/For + * \tparam TObjectRef The type of the annotation value + * \param sref The sref to the block or the for loop + * \param ann_key The annotation key to be looked up + * \return NullOpt if not found; otherwise the annotation value + */ +template +inline Optional GetAnn(const StmtSRef& sref, const String& ann_key) { + if (const auto* loop = sref->StmtAs()) { + return GetAnn(loop, ann_key); + } else if (const auto* block = sref->StmtAs()) { + return GetAnn(block, ann_key); + } else { + LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); + throw; + } +} + +/*! + * \brief Check if a Block/For has a specific pair of annotation key and values + * \param sref The sref to the block or the for loop + * \param ann_key The annotation key to be checked + * \param ann_val The string annotation value to be checked + * \return Whether a Block/For has a specific pair of annotation key and values + */ +inline bool HasAnn(const StmtSRef& sref, const String& ann_key, const String& ann_val) { + Optional result = GetAnn(sref, ann_key); + return result.defined() && result.value() == ann_val; +} + +/*! + * \brief Check if a Block/For has a specific pair of annotation key and values + * \param sref The sref to the block or the for loop + * \param ann_key The annotation key to be checked + * \param ann_val The string annotation value to be checked + * \return Whether a Block/For has a specific pair of annotation key and values + */ +inline bool HasAnn(const StmtSRef& sref, const String& ann_key, const char* ann_val) { + Optional result = GetAnn(sref, ann_key); + return result.defined() && result.value() == ann_val; +} + +/*! + * \brief Check if a Block/For has a specific pair of annotation key and values + * \param sref The sref to the block or the for loop + * \param ann_key The annotation key to be checked + * \param ann_val The boolean annotation value to be checked + * \return Whether a Block/For has a specific pair of annotation key and values + */ +inline bool HasAnn(const StmtSRef& sref, const String& ann_key, bool ann_val) { + Optional result = GetAnn(sref, ann_key); + return result.defined() && result.value()->value == ann_val; +} + +/******** Tensorization ******/ +/*! + * \brief Rewrite the block's outer loops to match the tensor intrin + * \param sch The schedule + * \param block_rv The block_rv we want to rewrite + * \param intrin_name The name of the tensor intrin we want to match + */ +Optional TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, + const String& intrin_name); + +/*! + * \brief Substitute the var in current block scope specified in key->var to be value. + * \param stmt The source stmt to be substituted + * \param value_func The function of new values mapping. + * \return The converted stmt. + */ +Stmt SubstituteInScope(const Stmt& stmt, const std::function& value_func); + +/*! + * \brief Substitute the var in current block scope specified in var map + * \param stmt The source stmt to be substituted + * \param var_map The mapping of var + * \return The converted stmt + */ +Stmt SubstituteInScope(const Stmt& stmt, + const std::unordered_map& var_map); + +/*! + * \param var_map The mapping of var + * \return The converted stmt + */ +Stmt SubstituteInScope(const Stmt& stmt, + const std::unordered_map& var_map); + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/apply_block_bound_predicate.cc b/src/tir/transforms/apply_block_bound_predicate.cc new file mode 100644 index 000000000000..2e93f4b13063 --- /dev/null +++ b/src/tir/transforms/apply_block_bound_predicate.cc @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file apply_block_bound_predicate.cc + * \brief Apply the block iter bound predicate to loops. + */ + +#include +#include +#include + +#include "../../arith/pattern_match.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +class BoundPredicateParserSimplifier : public ExprMutator { + public: + explicit BoundPredicateParserSimplifier(Map binding_map, + Map* bound_intset) + : binding_map_(std::move(binding_map)), bound_intset_(bound_intset) {} + + private: + PrimExpr VisitExpr(const PrimExpr& expr) final { + if (expr->IsInstance() || expr->IsInstance() || expr->IsInstance()) { + return ExprMutator::VisitExpr(expr); + } + ICHECK(false) << "InternalError: PrimExpr \"" << expr + << "\" is not supposed to appear as a bound predicate"; + throw; + } + + PrimExpr VisitExpr_(const LTNode* lt) final { + const VarNode* var = lt->a.as(); + if (!var) { + ICHECK(false) << "InternalError: LHS of logical expression here is required to be variables"; + } + Optional binding = binding_map_.Get(GetRef(var)); + if (!binding.defined()) { + ICHECK(false) << "InternalError: The LHS variable is supposed to be a block iterator"; + } + const VarNode* loop_var = binding.value().as(); + if (!loop_var) { + return GetRef(lt); + } + + arith::IntSet intset = + bound_intset_->Get(GetRef(loop_var)).value_or(arith::IntSet::Everything()); + intset = arith::Intersect( + {intset, arith::IntSet::FromRange(Range(min_value(lt->b.dtype()), lt->b))}); + bound_intset_->Set(GetRef(loop_var), intset); + return const_true(); + } + + PrimExpr VisitExpr_(const GENode* ge) final { + const VarNode* var = ge->a.as(); + if (!var) { + ICHECK(false) << "InternalError: LHS of logical expression here is required to be variables"; + } + Optional binding = binding_map_.Get(GetRef(var)); + if (!binding.defined()) { + ICHECK(false) << "InternalError: The LHS variable is supposed to be a block iterator"; + } + const VarNode* loop_var = binding.value().as(); + if (!loop_var) { + return GetRef(ge); + } + + arith::IntSet intset = + bound_intset_->Get(GetRef(loop_var)).value_or(arith::IntSet::Everything()); + intset = arith::Intersect( + {intset, arith::IntSet::FromRange(Range(ge->b, max_value(ge->b.dtype())))}); + bound_intset_->Set(GetRef(loop_var), intset); + return const_true(); + } + + Map binding_map_; + Map* bound_intset_; +}; + +/*! + * \brief Narrow the extents of some loops by checking whether some constraints in the block iter + * bound predicates can be directly applied on the loops. + */ +class LoopExtentMutator : public StmtMutator { + private: + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + // Step 1. Mutate recursively. + BlockRealize new_realize = Downcast(StmtMutator::VisitStmt_(realize)); + // Step 2. If the block has no "require_block_var_bound_predicate" annotation, skip this block. + Block block = new_realize->block; + const Optional& bound_predicate = + block->annotations.Get(tir::attr::require_block_var_bound_predicate); + if (!bound_predicate.defined()) { + return new_realize; + } + // Step 3. Make a mapping from block iters to bindings. + Map binding_map; + ICHECK_EQ(block->iter_vars.size(), new_realize->iter_values.size()); + int n_iter = static_cast(block->iter_vars.size()); + for (int i = 0; i < n_iter; ++i) { + binding_map.Set(block->iter_vars[i]->var, new_realize->iter_values[i]); + } + // Step 4. Parse the bound predicate, removing constraints on the block vars whose binding are + // single vars. + PrimExpr new_predicate = BoundPredicateParserSimplifier( + binding_map, &bound_intset_)(Downcast(bound_predicate.value())); + // Step 5. Update the block annotation and update the new block-realize. + ObjectPtr p_new_block = CopyOnWrite(block.get()); + if (ana_.CanProveEqual(new_predicate, const_true())) { + p_new_block->annotations.erase(tir::attr::require_block_var_bound_predicate); + } else { + p_new_block->annotations.Set(tir::attr::require_block_var_bound_predicate, new_predicate); + } + ObjectPtr p_new_realize = CopyOnWrite(new_realize.get()); + p_new_realize->block = Block(p_new_block); + + return BlockRealize(p_new_realize); + } + + Stmt VisitStmt_(const ForNode* loop) final { + // Step 1. Mutate recursively. + For new_loop = Downcast(StmtMutator::VisitStmt_(loop)); + // Step 2. Check whether this loop has a bound intset. If not, return the new loop. + Optional intset = bound_intset_.Get(new_loop->loop_var); + if (!intset.defined()) { + return new_loop; + } + // Step 3. Update the new loop's `min` and `extent` according to the extent. + PrimExpr new_min = max(new_loop->min, intset.value().min()); + PrimExpr new_extent = min(new_loop->min + new_loop->extent, intset.value().max() + 1) - new_min; + // Step 4. Update the new loop. + ObjectPtr p_new_loop = CopyOnWrite(new_loop.get()); + p_new_loop->min = ana_.Simplify(new_min); + p_new_loop->extent = ana_.Simplify(new_extent); + + return For(p_new_loop); + } + + /*! \brief The bounds of loop vars, provided by the block iter bound predicate */ + Map bound_intset_; + /*! \brief The analyzer */ + arith::Analyzer ana_; +}; + +PrimFunc ApplyBlockBoundPredicate(PrimFunc f) { + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(f)) { + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = LoopExtentMutator()(f->body); + return f; + } else { + return f; + } +} + +namespace transform { + +Pass ApplyBlockBoundPredicate() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return ApplyBlockBoundPredicate(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.ApplyBlockBoundPredicate", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.ApplyBlockBoundPredicate") + .set_body_typed(ApplyBlockBoundPredicate); +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 07f977860d93..86ddb53da9ed 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -53,12 +53,29 @@ Region SimplifyAndNarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set, for (size_t i = 0; i < nd_int_set.size(); ++i) { const arith::IntSet& int_set = nd_int_set[i]; Range range = int_set.CoverRange(Range(/*begin=*/0, /*end=*/original_shape[i])); - result.push_back( - Range::FromMinExtent(analyzer->Simplify(range->min), analyzer->Simplify(range->extent))); + result.push_back(Range::FromMinExtent( + analyzer->Simplify(range->min), analyzer->Simplify(min(original_shape[i], range->extent)))); } return result; } +NDIntSet NDIntSetEval(Region region, PrimExpr predicate, + std::unordered_map& dom_map, + arith::Analyzer* analyzer) { + std::unordered_map var_dom; + for (const auto& it : dom_map) { + var_dom[GetRef(it.first)] = it.second.CoverRange(Range::FromMinExtent(0, 0)); + } + Optional> eval_res = + arith::EstimateRegionLowerBound(region, var_dom, predicate, analyzer); + if (eval_res.defined()) { + NDIntSet res(0); + for (const auto& it : eval_res.value()) res.push_back(it); + return res; + } + return support::NDIntSetEval(support::NDIntSetFromRegion(region), dom_map); +} + /*! * \brief Collect the access region of each buffer. * \note The param buffer regions will not be collected. @@ -149,7 +166,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { } return; } - return StmtExprVisitor::VisitExpr_(op); + StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const BlockNode* op) final { @@ -198,6 +215,13 @@ class BufferAccessRegionCollector : public StmtExprVisitor { } } + void VisitStmt_(const BlockRealizeNode* op) final { + PrimExpr cur_predicate = predicate_in_scope; + predicate_in_scope = op->predicate; + StmtExprVisitor::VisitStmt_(op); + predicate_in_scope = cur_predicate; + } + /**************** Helper functions ****************/ void VisitBufferAccess(const BufferRegion& buffer_region) { @@ -206,7 +230,6 @@ class BufferAccessRegionCollector : public StmtExprVisitor { if (it != buffer_var_in_scope_.end()) { const Buffer& buffer = it->second.first; size_t n_ancestor_loops = it->second.second; - NDIntSet nd_int_set = support::NDIntSetFromRegion(buffer_region->region); // Step 1. Stop ancestor loop vars out of the allocation block from // being relaxed unless NeedRelaxThread() is true. std::vector non_relaxed(n_ancestor_loops); @@ -222,7 +245,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor { dom_map_.erase(dom_it); } // Step 2. Relax the access region - nd_int_set = support::NDIntSetEval(nd_int_set, dom_map_); + NDIntSet nd_int_set = + NDIntSetEval(buffer_region->region, predicate_in_scope, dom_map_, &dom_analyzer_); // Step 3. Restore the non-relaxed ancestor loops domain for (size_t i = 0; i < n_ancestor_loops; ++i) { const VarNode* v = ancestor_loops_[i]->loop_var.get(); @@ -279,6 +303,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor { */ std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_var_in_scope_; + /*! \brief The block predicate of current scope */ + PrimExpr predicate_in_scope{true}; /*! \brief The map from loop vars to their iter range. */ std::unordered_map dom_map_; diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index e9d99cda7e13..e3b32cb6c460 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -65,6 +65,12 @@ class BufferFlattener : public StmtExprMutator { if (!is_one(predicate)) { body = IfThenElse(predicate, std::move(body)); } + // If the block has bound predicates, transform it to if-then-else + const Optional& bound_predicate = + new_block->annotations.Get(tir::attr::require_block_var_bound_predicate); + if (bound_predicate.defined()) { + body = IfThenElse(Downcast(bound_predicate.value()), std::move(body)); + } // Step 3. Handle allocations in reverse order for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) { const Buffer& buffer = new_block->alloc_buffers[i - 1]; diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc new file mode 100644 index 000000000000..a5cecc1d4707 --- /dev/null +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -0,0 +1,808 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file inject_software_pipeline.cc + * \brief Transform annotated loops into pipelined one that parallelize producers and consumers + */ +#include +#include +#include + +#include "../../support/utils.h" +#include "../schedule/utils.h" +#include "./ir_utils.h" + +namespace tvm { +namespace tir { + +namespace software_pipeline { + +Block MakeBlock(const Stmt& body, const Map& buffer_data_to_buffer) { + Block block = Block({}, {}, {}, "", body); + auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer); + auto* n = block.CopyOnWrite(); + n->reads = access[0]; + n->writes = access[1]; + return block; +} + +struct PipelineStageOrder { + int stage; + int order; + PipelineStageOrder(int stage, int order) : stage(stage), order(order) {} +}; + +using PipelineInfo = std::unordered_map; + +struct BufferAccessInfo { + int def; // the defining stage of the buffer + int use; // the last using stage of the buffer + BufferAccessInfo(int def = -1, int use = -1) : def(def), use(use){}; +}; + +/*! + * \brief Rewriter for the body of the software pipeline. This pass inserts `floormod` to indices + * of accessing to remapped buffer to select the version corresponding to the pipeline stage. + */ +class PipelineBodyRewriter : public StmtExprMutator { + public: + /*! + * \brief Constructor of PipelineBodyRewriter. + * \param buffer_data_to_buffer The map from buffer data to buffer. + * \param buffer_remap The map from original buffer to the buffer with updated shape for + * multi-versioning in the sofeware pipeline. + * \param pipeline_loop The original loop to be software pipelined. + * \param access_all_versions Whether all versions the the buffers in the software pipeline are + * accessed. This will be used to update block access region. In the prologue and epilogue + * of a two-stage software pipeline, only one version of these buffers are accessed. + */ + PipelineBodyRewriter(const Map& buffer_data_to_buffer, + const Map& buffer_remap, For pipeline_loop, + bool access_all_versions) + : buffer_data_to_buffer_(buffer_data_to_buffer), + buffer_remap_(buffer_remap), + pipeline_loop_(pipeline_loop), + access_all_versions_(access_all_versions) {} + + private: + BufferRegion RewritePipelineBufferRegion(const BufferRegion& buffer_region) const { + auto it = buffer_remap_.find(buffer_region->buffer); + if (it != buffer_remap_.end()) { + Region new_region = buffer_region->region; + const Buffer& new_buffer = (*it).second; + // For pipeline buffers, always relax the access region of the first dimension to full extent + Range accessed_version = + access_all_versions_ + ? Range::FromMinExtent(0, new_buffer->shape[0]) + : Range::FromMinExtent(floormod((pipeline_loop_->loop_var - pipeline_loop_->min), + new_buffer->shape[0]), + Integer(1)); + new_region.insert(new_region.begin(), accessed_version); + return BufferRegion(new_buffer, new_region); + } + return buffer_region; + } + + Stmt VisitStmt_(const BlockNode* op) final { + for (const Buffer& alloc_buffer : op->alloc_buffers) { + buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer); + } + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + BlockNode* n = block.CopyOnWrite(); + n->reads.MutateByApply( + std::bind(&PipelineBodyRewriter::RewritePipelineBufferRegion, this, std::placeholders::_1)); + n->writes.MutateByApply( + std::bind(&PipelineBodyRewriter::RewritePipelineBufferRegion, this, std::placeholders::_1)); + for (const Buffer& alloc_buffer : op->alloc_buffers) { + buffer_data_to_buffer_.erase(alloc_buffer->data); + } + return block; + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + auto it = buffer_remap_.find(store->buffer); + if (it == buffer_remap_.end()) { + return std::move(store); + } + const Buffer& new_buffer = (*it).second; + auto* n = store.CopyOnWrite(); + n->buffer = new_buffer; + PrimExpr version = + floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); + n->indices.insert(n->indices.begin(), version); + return std::move(store); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + auto it = buffer_remap_.find(load->buffer); + if (it == buffer_remap_.end()) { + return std::move(load); + } + const Buffer& new_buffer = (*it).second; + auto* n = load.CopyOnWrite(); + n->buffer = new_buffer; + PrimExpr version = + floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); + n->indices.insert(n->indices.begin(), version); + return std::move(load); + } + + PrimExpr RewriteWmmaFragmentIndex(const Buffer& old_buffer, const Buffer& new_buffer, + const PrimExpr& old_index) { + PrimExpr new_buffer_offset = old_index; + + const int fragment_size = 256; + PrimExpr offset = + floordiv(foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, + make_const(DataType::Int(32), 1), old_buffer->shape), + fragment_size); + new_buffer_offset += + floormod(pipeline_loop_->loop_var - pipeline_loop_->min, new_buffer->shape[0]) * offset; + return new_buffer_offset; + } + + PrimExpr VisitExpr_(const CallNode* op) final { + // Intrinsic calls should be handled explicitly here as they are opaque accesses to + // buffer. + static const auto& load_matrix_sync = builtin::tvm_load_matrix_sync(); + static const auto& store_matrix_sync = builtin::tvm_store_matrix_sync(); + static const auto& mma_sync = builtin::tvm_mma_sync(); + static const auto& access_ptr = builtin::tvm_access_ptr(); + Call call = Downcast(StmtExprMutator::VisitExpr_(op)); + if (call->op.same_as(load_matrix_sync) || call->op.same_as(store_matrix_sync)) { + const Buffer& buffer = buffer_data_to_buffer_.at(Downcast(call->args[0])); + auto it = buffer_remap_.find(buffer); + if (it != buffer_remap_.end()) { + Array new_args = call->args; + const Buffer& new_buffer = (*it).second; + new_args.Set(4, RewriteWmmaFragmentIndex(buffer, new_buffer, call->args[4])); + return Call(call->dtype, call->op, new_args, call->span); + } + } else if (call->op.same_as(mma_sync)) { + Array new_args = call->args; + for (int i = 0; i < 4; i++) { + const Var& buffer_var = Downcast(call->args[i * 2]); + const PrimExpr& index = call->args[i * 2 + 1]; + const Buffer& buffer = buffer_data_to_buffer_.at(buffer_var); + auto it = buffer_remap_.find(buffer); + if (it != buffer_remap_.end()) { + PrimExpr new_index = RewriteWmmaFragmentIndex(buffer, (*it).second, index); + new_args.Set(i * 2 + 1, new_index); + } + } + return Call(call->dtype, call->op, new_args, call->span); + } else if (call->op.same_as(access_ptr)) { + const Buffer& buffer = buffer_data_to_buffer_.at(Downcast(call->args[1])); + auto it = buffer_remap_.find(buffer); + if (it != buffer_remap_.end()) { + Array new_args = call->args; + const Buffer& new_buffer = (*it).second; + const PrimExpr& old_index = call->args[2]; + PrimExpr offset; + if (new_buffer->strides.empty()) { + offset = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, + make_const(DataType::Int(32), 1), buffer->shape); + } else { + offset = new_buffer->strides[0]; + } + PrimExpr new_index = old_index + floormod(pipeline_loop_->loop_var, 2) * offset; + new_args.Set(2, new_index); + return Call(call->dtype, call->op, new_args, call->span); + } + } + return std::move(call); + } + + Map buffer_data_to_buffer_; + Map buffer_remap_; + For pipeline_loop_; + bool access_all_versions_; +}; + +class PipelineRewriter : public StmtExprMutator { + public: + static Stmt Rewrite( + Map buffer_data_to_buffer, + const std::unordered_set& double_buffers, + const Array pipeline_allocs, const For& pipeline_loop, + const PipelineInfo& pipeline_info) { + PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs, pipeline_loop, + pipeline_info); + return rewriter.BuildPipeline(); + } + + private: + PipelineRewriter(Map buffer_data_to_buffer, + const std::unordered_set& double_buffers, + const Array& pipeline_allocs, const For& pipeline_loop, + const PipelineInfo& pipeline_info) + + : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), + double_buffers_(double_buffers), + pipeline_allocs_(pipeline_allocs), + pipeline_loop_(pipeline_loop), + pipeline_info_(pipeline_info) {} + + Stmt BuildPipeline() { + // Step 1: Analyze accesses to the buffers in the pipeline and compute the number of versions + // need to maintain for each buffer. + RemapPipelineBuffers(pipeline_allocs_); + + ordered_stmts_.resize(pipeline_info_.size()); + for (const auto& pair : pipeline_info_) { + const Block& block = pair.first; + int order = pair.second.order; + ordered_stmts_.Set(order, block); + } + + // Step 2: Emit the pipeline prologue, body and epilogue. + Stmt prologue = EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true); + Stmt body = EmitImpl(pipeline_loop_->min + max_stage_, + pipeline_loop_->min + pipeline_loop_->extent, false); + Stmt epilogue = EmitImpl(pipeline_loop_->min + pipeline_loop_->extent, + pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true); + + SeqStmt stmt = SeqStmt({prologue, body, epilogue}); + + // Step 3: Add annotations of nested software pipeline (if appliable) + stmt = AnnotateNestedPipeline(stmt); + + // Step 4: Make a new block that contains new buffer allocations after pipeline rewriting. + Array alloc_buffers; + for (const auto& alloc : pipeline_allocs_) { + auto it = buffer_remap_.find(alloc); + if (it != buffer_remap_.end()) { + alloc_buffers.push_back((*it).second); + } else { + alloc_buffers.push_back(alloc); + } + buffer_data_to_buffer_.erase(alloc->data); + } + Block block = MakeBlock(stmt, buffer_data_to_buffer_); + auto* n = block.CopyOnWrite(); + n->alloc_buffers = std::move(alloc_buffers); + return BlockRealize({}, Bool(true), block); + } + + private: + /*! + * \brief Annotate the result of software pipeline rewriting with user-provided annotations. + * + * When there are nested software pipelines, after rewriting the inner software pipeline, + * it is required to add annotations to the result of the inner software pipeline to specify + * the rewriting behavior of the outer software pipeline. + * This method expects the annotations `attr::nested_software_pipeline_order`, and + * `attr::nested_software_pipeline_stage` are present on the inner software pipeline loop. + * + * \param pipeline_seq The sequence of statements after pipeline rewriting, which consists of + * three BlockRealize that represents the prologue, the body, and the epilogue of the software + * pipeline. + * \return The sequence of the statements that consists of the annotated software pipeline. + */ + SeqStmt AnnotateNestedPipeline(const SeqStmt& pipeline_seq) { + auto it = pipeline_loop_->annotations.find(attr::nested_software_pipeline_stage); + if (it == pipeline_loop_->annotations.end()) { + return pipeline_seq; + } + Array nested_stage = Downcast>((*it).second); + CHECK(pipeline_loop_->annotations.count(attr::nested_software_pipeline_order)) + << "ValueError: Annotation for the order of the nested software pipeline is missing."; + Array nested_order = Downcast>( + pipeline_loop_->annotations.at(attr::nested_software_pipeline_order)); + CHECK_EQ(nested_stage.size(), 3) << "ValueError: Annotation for the stage of the nested " + "software pipeline should be a 3-tuple"; + CHECK_EQ(nested_order.size(), 3) << "ValueError: Annotation for the order of the nested " + "software pipeline should be a 3-tuple"; + Array new_seq; + new_seq.reserve(pipeline_seq->seq.size()); + for (size_t i = 0; i < pipeline_seq->seq.size(); i++) { + BlockRealize block_realize = Downcast(pipeline_seq->seq[i]); + auto* block = block_realize.CopyOnWrite()->block.CopyOnWrite(); + block->annotations.Set(attr::software_pipeline_stage, nested_stage[i]); + block->annotations.Set(attr::software_pipeline_order, nested_order[i]); + new_seq.push_back(std::move(block_realize)); + } + return SeqStmt(std::move(new_seq)); + } + + /*! + * \brief Analyze accesses to the buffers in the software pipeline. + * + * This method check the 'define' and 'use' stage of the buffers in the software pipeline, which + * can be used to compute the number of versions needed to maintain after rewriting. + */ + std::unordered_map + GetBufferAccessInfo() { + std::unordered_map infos; + for (const auto& pair : pipeline_info_) { + const Block& block = pair.first; + int stage = pair.second.stage; + max_stage_ = std::max(max_stage_, stage); + + for (const BufferRegion& write : block->writes) { + if (!infos.count(write->buffer)) { + infos.emplace(write->buffer, BufferAccessInfo{}); + } + auto& info = infos.at(write->buffer); + if (info.def == -1) { + info.def = stage; + } + } + + for (const BufferRegion& read : block->reads) { + if (!infos.count(read->buffer)) { + infos.emplace(read->buffer, BufferAccessInfo{}); + } + auto& info = infos.at(read->buffer); + info.use = std::max(info.use, stage); + } + } + return infos; + } + + /*! + * \brief Check whether two regions have intersections. + * \param region1 The first region. + * \param region2 The second region. + * \return Whether region1 and region2 have intersections. + */ + bool MayConflict(Region region1, Region region2) { + ICHECK(region1.size() == region2.size()); + for (size_t i = 0; i < region1.size(); i++) { + Range dim1 = region1[i]; + Range dim2 = region2[i]; + auto int_set1 = arith::IntSet::FromRange(dim1); + auto int_set2 = arith::IntSet::FromRange(dim2); + if (arith::Intersect({int_set1, int_set2}).IsNothing()) { + return false; + } + } + return true; + } + + /*! + * \brief Compute the number of versions need to maintain for buffer accessed in the software + * pipeline. + * + * This method applies liveness analysis to the target buffer to compute the number of versions + * need to maintain during the software pipeline. + * Annotation `attr::double_buffer_scope` is handled here which provides a way to override the + * result of the analysis. Additional double buffering in the software pipeline can be useful + * to eliminate synchonizations in GPU devices. + * + * \param buffer The target buffer + * \param buffer_info The access information of the target buffer. + * \return The number of versions required for the target buffer. + */ + int ComputeBufferVersions(const Buffer& buffer, const BufferAccessInfo& buffer_info) { + if (buffer_info.def == -1) { + // Keep the original number of versions as buffers defined outside the software pipeline + // should not be mutated. + return 1; + } + + // `use - def + 1` is a upper bound of the needed versions + // We optimize a few case where the number of versions can be smaller than the upper bound + int num_versions = buffer_info.use - buffer_info.def + 1; + if (num_versions == 2) { + // A special case when `use - def + 1 == 2`. Double buffering is only needed in this case when + // these exists a reader block_i and a writer block_j such that + // order(block_i) < order(block_j) and stage(block_i) < stage(block_j) and the access regions + // of block_i and block_j overlap. + bool need_multi_version = false; + for (const auto& pair1 : pipeline_info_) { + const Block& writer_block = pair1.first; + const auto& writer_info = pair1.second; + + auto it1 = std::find_if(writer_block->writes.begin(), writer_block->writes.end(), + [&](const BufferRegion& buffer_region) { + return buffer_region->buffer.same_as(buffer); + }); + if (it1 == writer_block->writes.end()) { + continue; + } + + for (const auto& pair2 : pipeline_info_) { + const Block& reader_block = pair2.first; + const auto& reader_info = pair2.second; + auto it2 = std::find_if(reader_block->reads.begin(), reader_block->reads.end(), + [&](const BufferRegion& buffer_region) { + return buffer_region->buffer.same_as(buffer); + }); + if (it2 == reader_block->reads.end()) { + continue; + } + if (writer_info.order < reader_info.order && writer_info.stage < reader_info.stage && + MayConflict((*it1)->region, (*it2)->region)) { + need_multi_version = true; + break; + } + } + } + if (!need_multi_version) { + num_versions = 1; + } + } + if (num_versions == 1 && double_buffers_.count(buffer)) { + num_versions = 2; + } + return num_versions; + } + + /*! + * \brief Rewrite buffer allocations to create new buffers with new shapes according to + * the software pipeline. + * \param pipeline_allocs The buffer allocations inside the software pipeline scope. + */ + void RemapPipelineBuffers(Array pipeline_allocs) { + std::unordered_map infos = + GetBufferAccessInfo(); + for (const auto& pair : infos) { + const Buffer& buffer = pair.first; + const BufferAccessInfo& buffer_info = pair.second; + int num_versions = ComputeBufferVersions(buffer, buffer_info); + if (num_versions > 1) { + Buffer new_buffer = RewriteAllocBuffer(buffer, num_versions); + CHECK(std::find(pipeline_allocs.begin(), pipeline_allocs.end(), buffer) != + pipeline_allocs.end()); + buffer_remap_.Set(pair.first, new_buffer); + } + } + } + + /*! + * \brief Rewrite buffer allocation to keep multiple versions of original buffer for pipelined + * accesses. + * \param buffer The buffer to be resized. + * \param num_versions The number of versions to keep. + * \return The resized buffer. + */ + Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) { + ObjectPtr new_buffer = make_object(*(buffer.get())); + new_buffer->shape.insert(new_buffer->shape.begin(), num_versions); + if (new_buffer->strides.size()) { + ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); + PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1]; + new_buffer->strides.insert(new_buffer->strides.begin(), stride_0); + } + return Buffer(new_buffer); + } + + Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop) { + Array stmts; + PrimExpr new_loop_var; + bool is_unit_loop = analyzer_.CanProveEqual(start + 1, end); + if (is_unit_loop) { + new_loop_var = start; + } else { + new_loop_var = pipeline_loop_->loop_var.copy_with_suffix(""); + analyzer_.Bind(Downcast(new_loop_var), Range(start, end), true); + } + + for (const Block block : ordered_stmts_) { + int stage = pipeline_info_.at(block).stage; + PrimExpr skewed_loop_var = new_loop_var - stage; + PrimExpr inbound = (skewed_loop_var >= pipeline_loop_->min) && + (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent); + inbound = analyzer_.Simplify(inbound); + if (analyzer_.CanProve(!inbound)) { + continue; + } + Block new_block = Downcast(PipelineBodyRewriter( + buffer_data_to_buffer_, buffer_remap_, pipeline_loop_, max_stage_ != 1)(block)); + Map subst_map; + if (is_unit_loop) { + subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var); + } else { + // normalize loop range + subst_map.Set(pipeline_loop_->loop_var, skewed_loop_var + (start - pipeline_loop_->min)); + } + new_block = Downcast(Substitute(new_block, subst_map)); + stmts.push_back(BlockRealize({}, inbound, new_block)); + } + + Stmt stmt; + if (is_unit_loop) { + stmt = stmts.size() == 1 ? stmts[0] : SeqStmt(stmts); + } else { + stmt = For(Downcast(new_loop_var), pipeline_loop_->min, end - start, + unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, SeqStmt(stmts)); + } + if (stmt->IsInstance()) { + return stmt; + } + return BlockRealize({}, Bool(true), MakeBlock(stmt, buffer_data_to_buffer_)); + } + + arith::Analyzer analyzer_; + Map buffer_data_to_buffer_; + const std::unordered_set& double_buffers_; + Array pipeline_allocs_; + For pipeline_loop_; + PipelineInfo pipeline_info_; + int max_stage_ = -1; + Map buffer_remap_; + Array ordered_stmts_; +}; + +class PipelineInjector : private StmtExprMutator { + public: + static Stmt Inject(const PrimFunc& func) { + PipelineInjector injector; + for (const auto& kv : func->buffer_map) { + const Buffer& buffer = kv.second; + injector.buffer_data_to_buffer_.Set(buffer->data, buffer); + } + return injector(func->body); + } + + private: + PipelineInjector() = default; + + PipelineStageOrder CheckAndRemovePipelineAnnotation(Map* annotations) const { + CHECK(annotations->count(attr::software_pipeline_stage)) + << "ValueError: Stage of the statement in the software pipeline is not defined."; + CHECK(annotations->count(attr::software_pipeline_order)) + << "ValueError: Order of the statement in the software pipeline is not defined."; + Integer stage = Downcast(annotations->at(attr::software_pipeline_stage)); + Integer order = Downcast(annotations->at(attr::software_pipeline_order)); + annotations->erase(attr::software_pipeline_stage); + annotations->erase(attr::software_pipeline_order); + return {static_cast(stage->value), static_cast(order->value)}; + } + + /*! + * \brief Check the pipeline satisfies the following conditions: + * 1) No conflicting order: The order of each statement should be unique. + * 2) No reordering with the same stage: Statements in the same stage are not allowed to be + * reordered. + */ + void ValidatePipelineBody(const PipelineInfo& pipeline_info, const Array& original_order) { + std::unordered_set used_orders; + std::unordered_map stage_max_order; + for (const Block& block : original_order) { + const auto& stmt_info = pipeline_info.at(block); + int stage = stmt_info.stage; + int order = stmt_info.order; + CHECK(!used_orders.count(order)) + << "ValueError: Two statements in the software pipeline cannot have the same order"; + used_orders.insert(order); + CHECK(!stage_max_order.count(stage) || stage_max_order[stage] < order) + << "ValueError: Statements in the same stage of the software pipeline must have " + "increasing order."; + stage_max_order[stage] = order; + } + } + + Stmt VisitStmt_(const ForNode* op) final { + // Step 1: Recursively rewrite the children first. + For for_node = Downcast(StmtExprMutator::VisitStmt_(op)); + bool is_pipeline = HasPipelineAnnotation(op); + if (!is_pipeline) { + return std::move(for_node); + } + // Step 2: Find the body of the pipeline. It can be direct child of the for-loop. If the + // for-loop as BlockRealize as its child, the pipeline body will be the child of the block. + Stmt pipeline_body; + Array pipeline_allocs; + if (const auto* realize = for_node->body.as()) { + const auto& block = realize->block; + for (const auto& buffer : block->alloc_buffers) { + ICHECK(buffer->IsInstance()); + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + pipeline_body = block->body; + pipeline_allocs = block->alloc_buffers; + } else { + pipeline_body = for_node->body; + } + + const SeqStmtNode* pipeline_body_seq = pipeline_body.as(); + CHECK(pipeline_body_seq) + << "ValueError: The body of the software pipeline should be SeqStmt, got " + << pipeline_body->GetTypeKey(); + const SeqStmtNode* original_seq = + op->body->IsInstance() + ? op->body.as()->block->body.as() + : op->body.as(); + ICHECK(original_seq); + + // Step 3: Blockize the components of the pipeline. Each child of the pipelined loop should + // be converted into a block. + PipelineInfo pipeline_info; + Array original_order; + + auto f_add_child = [&](const Stmt& child) { + const auto* block_realize = child.as(); + Block block = (block_realize && is_one(block_realize->predicate)) + ? block_realize->block + : MakeBlock(child, buffer_data_to_buffer_); + original_order.push_back(block); + }; + for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) { + const auto* nested_block_realize = pipeline_body_seq->seq[i].as(); + if (nested_block_realize && is_one(nested_block_realize->predicate) && + nested_block_realize->block->body->IsInstance()) { + const Block& nested_pipeline_block = nested_block_realize->block; + ICHECK( + nested_pipeline_block->match_buffers.empty()); // match_buffer should have been lowered + for (const auto& buffer : nested_pipeline_block->alloc_buffers) { + pipeline_allocs.push_back(buffer); + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + const auto* nested_seq = nested_pipeline_block->body.as(); + for (size_t j = 0; j < nested_seq->seq.size(); j++) { + f_add_child(nested_seq->seq[j]); + } + } else { + f_add_child(pipeline_body_seq->seq[i]); + } + } + + auto pipeline_stages = + Downcast>(op->annotations.at(attr::software_pipeline_stage)); + auto pipeline_orders = + Downcast>(op->annotations.at(attr::software_pipeline_order)); + CHECK_EQ(pipeline_stages.size(), original_order.size()); + CHECK_EQ(pipeline_orders.size(), original_order.size()); + for (size_t i = 0; i < pipeline_stages.size(); i++) { + PipelineStageOrder stage_order(pipeline_stages[i]->value, pipeline_orders[i]->value); + pipeline_info.emplace(original_order[i], stage_order); + } + // ValidatePipelineBody(pipeline_info, original_order); + + // Step 4: Rewrite the pipeline body. + Stmt pipeline = PipelineRewriter::Rewrite(buffer_data_to_buffer_, double_buffers, + pipeline_allocs, GetRef(op), pipeline_info); + + if (const auto* realize = op->body.as()) { + const auto& block = realize->block; + for (const auto& buffer : block->alloc_buffers) { + buffer_data_to_buffer_.erase(buffer->data); + } + } + return pipeline; + } + + /*! + * \brief Add buffer allocations to a block and update the write region of the block. + * \param n The block pointer to which the buffer allocations are added. + * \param alloc_buffers The buffer allocations to be added. + */ + void AddAllocBuffers(BlockNode* n, const Array alloc_buffers) { + for (const Buffer& alloc_buffer : alloc_buffers) { + n->alloc_buffers.push_back(alloc_buffer); + Region region; + region.reserve(alloc_buffer->shape.size()); + for (const PrimExpr& dim : alloc_buffer->shape) { + region.push_back(Range::FromMinExtent(0, dim)); + } + n->writes.push_back(BufferRegion(alloc_buffer, region)); + } + } + + /*! + * \brief Flatten nested SeqStmt while passing through BlockRealize / Block. + * \param block The block which has SeqStmt body to rewrite. + * \return The new block that contains flattened SeqStmt as its body. + */ + Block FlattenNestedBlocks(Block block) { + const SeqStmtNode* seq = block->body.as(); + auto* n = block.CopyOnWrite(); + Array new_seq; + new_seq.reserve(seq->seq.size()); + bool changed = false; + for (size_t i = 0; i < seq->seq.size(); i++) { + const auto* nested_block_realize = seq->seq[i].as(); + if (!nested_block_realize || !is_one(nested_block_realize->predicate) || + !nested_block_realize->block->body->IsInstance()) { + new_seq.push_back(seq->seq[i]); + continue; + } + AddAllocBuffers(n, nested_block_realize->block->alloc_buffers); + const auto* nested_seq = nested_block_realize->block->body.as(); + new_seq.reserve(new_seq.size() + nested_seq->seq.size()); + for (const auto& nested_seq_body : nested_seq->seq) { + new_seq.push_back(nested_seq_body); + } + changed = true; + } + if (changed) { + n->body = SeqStmt(new_seq); + } + return block; + } + + Stmt VisitStmt_(const BlockNode* op) final { + for (const auto& buffer : op->alloc_buffers) { + ICHECK(buffer->IsInstance()); + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + + auto it = op->annotations.find(attr::double_buffer_scope); + if (it != op->annotations.end()) { + int buffer_index = Downcast((*it).second); + CHECK(buffer_index >= 0 && static_cast(buffer_index) < op->writes.size()) + << "ValueError: Index of the buffer exceeds the size of the write regions of the block. (" + << buffer_index << " vs. " << op->writes.size() << ")"; + double_buffers.insert(op->writes[buffer_index]->buffer); + } + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + + // if (block->body->IsInstance()) { + // // Rewriting for software pipelining will produce nested SeqStmt. These statements need to + // be + // // flattened for rewriting outer software pipeline (if nested software pipelines are + // present). block = FlattenNestedBlocks(block); + // } + + for (const auto& buffer : op->alloc_buffers) { + buffer_data_to_buffer_.erase(buffer->data); + } + return block; + } + + bool HasPipelineAnnotation(const ForNode* op) const { + auto it1 = op->annotations.find(attr::software_pipeline_stage); + auto it2 = op->annotations.find(attr::software_pipeline_order); + bool has_stage = it1 != op->annotations.end(); + bool has_order = it2 != op->annotations.end(); + if (has_stage && has_order) { + return true; + } + if (has_stage) { + LOG(FATAL) << "ValueError: Order of the software pipeline is not defined."; + } + if (has_order) { + LOG(FATAL) << "ValueError: Stage of the software pipeline is not defined."; + } + return false; + } + + Map buffer_data_to_buffer_; + std::unordered_set double_buffers; +}; + +} // namespace software_pipeline + +namespace transform { + +/*! + * \brief Transform annotated loops into pipelined one that parallelize producers and consumers. + * \return The IR transform pass. + */ +Pass InjectSoftwarePipeline() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* fptr = f.CopyOnWrite(); + fptr->body = software_pipeline::PipelineInjector::Inject(f); + fptr->body = ConvertSSA(std::move(fptr->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InjectSoftwarePipeline", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.InjectSoftwarePipeline").set_body_typed(InjectSoftwarePipeline); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 2eea869af516..d2dd95a581ed 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -149,22 +149,22 @@ Array RemoveBufferFromBufferRegions(const Array& buf /*! * \brief Substitute a given source buffer with a given target buffer in statements or expressions */ -class BufferReplacer : private StmtExprMutator { +class BufferMutator : private StmtExprMutator { public: static Stmt Run(Buffer src_buffer, Buffer tgt_buffer, Stmt stmt) { - return BufferReplacer(src_buffer, tgt_buffer)(std::move(stmt)); + return BufferMutator(src_buffer, tgt_buffer)(std::move(stmt)); } private: - explicit BufferReplacer(Buffer src_buffer, Buffer tgt_buffer) + explicit BufferMutator(Buffer src_buffer, Buffer tgt_buffer) : src_buffer_(std::move(src_buffer)), tgt_buffer_(std::move(tgt_buffer)) {} - PrimExpr VisitExpr_(const BufferLoadNode* load) final { + PrimExpr VisitExpr_(const BufferLoadNode* load) override { return load->buffer.same_as(src_buffer_) ? BufferLoad(tgt_buffer_, {0}) : GetRef(load); } - Stmt VisitStmt_(const BufferStoreNode* store) final { + Stmt VisitStmt_(const BufferStoreNode* store) override { if (store->buffer.same_as(src_buffer_)) { PrimExpr value = StmtExprMutator::VisitExpr(store->value); return BufferStore(tgt_buffer_, value, {0}); @@ -287,7 +287,7 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optionalwrites = {it_buffer_region.value()}; new_block->name_hint = new_block->name_hint + "_in_thread"; new_block->body = - BufferReplacer::Run(wb_buffer, it_buffer.value(), std::move(new_block->body)); + BufferMutator::Run(wb_buffer, it_buffer.value(), std::move(new_block->body)); new_block->init = NullOpt; ObjectPtr n = make_object(*realize); n->block = Block(new_block); @@ -582,10 +582,12 @@ class CrossThreadReductionTransformer : public StmtMutator { PrimExpr combiner_rhs{nullptr}; std::tie(n_bound_reduction_loops, reducer, combiner_rhs) = CheckCanApplyCrossThreadReduction(block, reduction_loops); - // Step 3. When not all the reduction-related loops are bound to thread axes, in-thread - // reduction is needed in this cross-thread reduction. + // Step 3. Before doing the cross-thread reduction, in-thread reduction is needed when + // - not all the reduction-related loops are bound to thread axes, or + // - the block-realize has a non-constant-true predicate. bool need_in_thread_reduction = - n_bound_reduction_loops < static_cast(reduction_loops.size()); + n_bound_reduction_loops < static_cast(reduction_loops.size()) || + !is_one(realize->predicate); // Step 4. Create intermediate buffers, storing them in `ct_buffer` and // `it_buffer`. Let the scope block allocate these new buffers. std::vector& new_buffers = block2new_buffers_[block_stack_.back()]; diff --git a/src/tir/transforms/memhammer_coalesce.cc b/src/tir/transforms/memhammer_coalesce.cc new file mode 100644 index 000000000000..7925f4e090c4 --- /dev/null +++ b/src/tir/transforms/memhammer_coalesce.cc @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../../runtime/thread_storage_scope.h" +#include "./memhammer_rewrite_rule.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Fuse consecutive loops + * \param body the outer-most loop + * \return the fused loop + */ +Stmt FuseNestLoops(Stmt body) { + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + std::string suffix; + int n = loops.size(); + for (int i = 1; i < n; i++) { + suffix += "_" + loops[i]->loop_var->name_hint; + } + suffix += "_fused"; + Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix); + Map subst_map; + PrimExpr tot = fused_var; + for (int i = n - 1; i >= 0; i--) { + subst_map.Set(loops[i]->loop_var, floormod(tot, loops[i]->extent)); + tot = floordiv(tot, loops[i]->extent); + } + auto f_substitute = [&](const Var& v) -> Optional { + return subst_map.Get(v).value_or(v); + }; + PrimExpr fused_extent = 1; + for (int i = 0; i < n; i++) { + fused_extent *= loops[i]->extent; + } + return For(fused_var, 0, fused_extent, ForKind::kSerial, + Substitute(std::move(body), f_substitute)); +} + +/*! + * \brief a combination of split, bind, vectorize, + * a helper function to perform coalesced load/store + * \param stmt the stmt to do transformation + * \param constraints The constraints, including thread extents, vector bytes, and data bits. + * \return The stmt after transformation + */ +Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) { + const ForNode* loop = TVM_TYPE_AS(loop, stmt, ForNode); + int loop_extent = Downcast(loop->extent)->value; + int vector_bytes = constraints.vector_bytes; + int data_bits = constraints.data_bits; + int vector_len = std::max(1, vector_bytes * 8 / data_bits); + int tot_threads = 1; + // generate thread binding loops + std::vector factors{-1}; + std::vector thread_axis; + if (Optional o_t = constraints.thread_extent.Get("threadIdx.z")) { + int t = o_t.value()->value; + tot_threads *= t; + factors.push_back(t); + thread_axis.push_back("threadIdx.z"); + } + if (Optional o_t = constraints.thread_extent.Get("threadIdx.y")) { + int t = o_t.value()->value; + tot_threads *= t; + factors.push_back(t); + thread_axis.push_back("threadIdx.y"); + } + if (Optional o_t = constraints.thread_extent.Get("threadIdx.x")) { + int t = o_t.value()->value; + tot_threads *= t; + factors.push_back(t); + thread_axis.push_back("threadIdx.x"); + } + // generate vectorized loop + factors.push_back(vector_len); + // generate outer loop + ICHECK_EQ(loop_extent % (tot_threads * vector_len), 0); + factors[0] = loop_extent / (tot_threads * vector_len); + // create new loop vars + int n = factors.size(); + std::vector new_loop_vars; + new_loop_vars.reserve(n); + for (int i = 0; i < n; i++) { + new_loop_vars.push_back(loop->loop_var.copy_with_suffix("_" + std::to_string(i))); + } + // substitute fused loop var with new loop vars + PrimExpr substitute_value = 0; + for (int i = 0; i < n; i++) { + substitute_value *= factors[i]; + substitute_value += new_loop_vars[i]; + } + // Construct the new loop nest + Stmt body = Substitute(loop->body, [&](const Var& v) -> Optional { + if (v.same_as(loop->loop_var)) { + return substitute_value; + } else { + return NullOpt; + } + }); + body = For(new_loop_vars.back(), 0, vector_len, ForKind::kVectorized, std::move(body)); + for (int i = n - 2; i >= 1; i--) { + body = For(new_loop_vars[i], 0, factors[i], ForKind::kThreadBinding, std::move(body), + IterVar(Range(nullptr), Var(thread_axis[i - 1]), kThreadIndex, thread_axis[i - 1])); + } + return For(new_loop_vars[0], 0, factors[0], ForKind::kSerial, std::move(body)); +} + +Stmt CoalescedAccess::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt after_fuse = FuseNestLoops(stmt); + Stmt after_split = SplitBindVectorize(std::move(after_fuse), constraints); + return after_split; +} + +/*! + * \brief Get the index mapping of a specific stmt. + * The stmt is like: + * for i0: + * ... + * for in: + * A[f(i0, ..., in])] = B[i0, ..., in], + * where f is the index mapping we want to get. + * \param constraints The constraints, including the write region that is required to calculate + * the index mapping + * \return The mapping in the form of j0, ..., jm, where j0, ... jm = f(i0, ..., in) + */ +Array GetMapping(const Stmt& stmt, const ConstraintSet& constraints) { + Stmt body = stmt; + while (const ForNode* loop = body.as()) { + body = loop->body; + } + const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode); + BufferRegion write_region = constraints.write_region; + const Array& write_index = buf_store->indices; + ICHECK(write_region->region.size() == write_index.size() && + write_region->buffer.same_as(buf_store->buffer)); + Array result; + arith::Analyzer analyzer; + for (int i = 0; i < static_cast(write_region->region.size()); i++) { + PrimExpr pattern = analyzer.Simplify(write_index[i] - write_region->region[i]->min); + if (!is_zero(pattern)) { + result.push_back(pattern); + } + } + return result; +} + +Stmt InverseMapping::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt body = stmt; + Map var_range; + Array loop_vars; + // Step 1. Get index mapping + Array mapping_pattern = GetMapping(stmt, constraints); + while (const ForNode* loop = body.as()) { + var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + loop_vars.push_back(loop->loop_var); + body = loop->body; + } + // Step 2. Get Inverse mapping + arith::Analyzer analyzer; + DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); + Array iter_map = + arith::DetectIterMap(mapping_pattern, var_range, Bool(true), true, &analyzer, diag_ctx); + CHECK_EQ(iter_map.size(), loop_vars.size()); + Map inverse_mapping = arith::InverseAffineIterMap(iter_map, loop_vars); + // Step 3. Generate new body + BufferRegion read_region = constraints.read_region; + BufferRegion write_region = constraints.write_region; + Array write_index; + Array read_index; + Array new_loop_vars; + Map substitute_map; + // Step 3.1 construct target buffer indices + for (int i = 0, j = 0; i < static_cast(write_region->region.size()); i++) { + if (is_one(write_region->region[i]->extent)) { + write_index.push_back(write_region->region[i]->min); + } else { + Var var = runtime::Downcast(loop_vars[j]).copy_with_suffix("_inverse"); + new_loop_vars.push_back(var); + substitute_map.Set(runtime::Downcast(loop_vars[j++]), var); + write_index.push_back(write_region->region[i]->min + var); + } + } + // Step 3.2 construct source buffer indices + for (int i = 0, j = 0; i < static_cast(read_region->region.size()); i++) { + if (is_one(read_region->region[i]->extent)) { + read_index.push_back(read_region->region[i]->min); + } else { + read_index.push_back( + read_region->region[i]->min + + Substitute(inverse_mapping[Downcast(loop_vars[j++])], substitute_map)); + } + } + BufferLoad new_buf_load = BufferLoad(read_region->buffer, read_index); + BufferStore new_buf_store = BufferStore(write_region->buffer, new_buf_load, write_index); + Stmt ret = new_buf_store; + // Step 3.3 construct loop body + for (int i = static_cast(new_loop_vars.size()) - 1; i >= 0; i--) { + PrimExpr extent = write_region->region[i]->extent; + ret = For(new_loop_vars[i], 0, extent, ForKind::kSerial, std::move(ret)); + } + return ret; +} +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/memhammer_intermediate_stage.cc b/src/tir/transforms/memhammer_intermediate_stage.cc new file mode 100644 index 000000000000..4ffffc9fdeab --- /dev/null +++ b/src/tir/transforms/memhammer_intermediate_stage.cc @@ -0,0 +1,428 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "memhammer_rewrite_rule.h" + +namespace tvm { +namespace tir { + +Stmt CopyLoopChain(const std::vector loops, const Stmt& inner_body, int ith = -1, + Stmt* ith_loop = nullptr) { + Stmt ret = inner_body; + for (int i = static_cast(loops.size() - 1); i >= 0; i--) { + ObjectPtr new_loop = make_object(*loops[i]); + new_loop->body = ret; + ret = For(new_loop); + if (ith == i) { + *ith_loop = ret; + } + } + return ret; +} + +/*! + * \brief lift all the thread binding loops + * \param stmt the top loop + * \return a pair. The first is the transformed stmt. + * The second is the lowest thread binding loop. + */ +std::pair> LiftThreadBindingLoops(Stmt stmt) { + std::vector normal_loops; + std::vector thread_binding_loops; + Stmt body = stmt; + while (const ForNode* loop = body.as()) { + if (loop->kind == ForKind::kThreadBinding) { + thread_binding_loops.push_back(loop); + } else { + normal_loops.push_back(loop); + } + body = loop->body; + } + body = CopyLoopChain(normal_loops, std::move(body)); + For compute_location{nullptr}; + body = CopyLoopChain(thread_binding_loops, std::move(body), + static_cast(thread_binding_loops.size()) - 1, &compute_location); + return std::make_pair(body, compute_location); +} + +/*! + * \brief Analyze the access pattern for buffer rank promotion. + * Rank promotion is a transformation that reshapes the buffer + * but doesn't change its underlying data layout. + * After the reshape, we expect that all dimensions of the access indices + * will be in the form of floormod(floordiv(x, a), b). + * Rank promotion removes strided access, thus enabling further buffer compacting + */ +class IndexPatternFinder : public ExprVisitor { + public: + IndexPatternFinder(const Map& var_range, Array* resulting_index) + : var_range_(var_range), resulting_index_(resulting_index) {} + + /*! + * \brief Calculate the new buffer shape after rank promotion. + * For each dimension of original shape, it will be split into multiple parts. + * The inner array represents the multiple parts of one original dimension, + * and the outer array represents the original dimensions + * For example, original shape [4, 8] may be split into [[2, 2], [2, 4]] + * \param indices The access indices of the buffer + * \param var_range The iter range of the vars in the indices + * \param rewrite_indices The access indices after rank promotion + * \return The new buffer shape after rank promotion. + */ + static Array> GetRankPromotedShape(Array indices, + const Map& var_range, + Array* rewrite_indices) { + Map var_dom = AsIntSet(var_range); + Array> new_shape; + for (const PrimExpr& expr : indices) { + IndexPatternFinder extractor(var_range, rewrite_indices); + arith::IntSet intset = arith::EvalSet(expr, var_dom); + extractor.mod_ = intset.max() + 1; + extractor.div_ = 1; + extractor.offset_ = 0; + extractor(expr); + Array access_shape = extractor.access_shape_; + for (int i = static_cast(access_shape.size()) - 1; i >= 1; i--) { + if (!is_zero(floormod(extractor.offset_, access_shape[i]))) { + return {}; + } else { + extractor.offset_ = floordiv(extractor.offset_, access_shape[i]); + } + } + access_shape.Set(0, extractor.offset_ + access_shape[0]); + new_shape.push_back(access_shape); + } + return new_shape; + } + + private: + void VisitExpr_(const VarNode* op) final { + arith::Analyzer analyzer; + PrimExpr extent = var_range_[GetRef(op)]->extent; + PrimExpr access_iter_range = min(mod_, (max(1, floordiv(extent, div_)))); + if (!analyzer.CanProveEqual(1, access_iter_range)) { + access_shape_.push_back(access_iter_range); + resulting_index_->push_back(floormod(floordiv(GetRef(op), div_), mod_)); + } + } + + void VisitExpr_(const FloorDivNode* op) final { + PrimExpr old_div = div_; + div_ *= op->b; + ExprVisitor::VisitExpr_(op); + div_ = old_div; + } + + void VisitExpr_(const FloorModNode* op) final { + PrimExpr old_mod = mod_; + mod_ = max(1, min(floordiv(op->b, div_), mod_)); + ExprVisitor::VisitExpr_(op); + mod_ = old_mod; + } + + void VisitExpr_(const MulNode* op) final { + PrimExpr old_mod = mod_; + PrimExpr old_div = div_; + div_ = max(1, floordiv(div_, op->b)); + mod_ = max(1, floordiv(mod_, floordiv(op->b, floordiv(old_div, div_)))); + ExprVisitor::VisitExpr_(op); + mod_ = old_mod; + div_ = old_div; + } + + void VisitExpr_(const AddNode* op) final { + if (is_const_int(op->b)) { + offset_ += floormod(floordiv(op->b, div_), mod_); + } + ExprVisitor::VisitExpr_(op); + } + + PrimExpr div_; + PrimExpr mod_; + PrimExpr offset_; + Map var_range_; + Array access_shape_; + Array* resulting_index_; +}; + +/*! + * \brief Utilities to perform rank promotion + */ +class RankPromoter : public StmtExprMutator { + public: + /*! + * \brief Flatten the buffer shape like performing inverse rank promotion. + * For example, [[i0, i1], [j0, j1]] to [i0 * i1, j0 * j1] + * \param new_shape The buffer shape in the special form as returned by GetRankPromotedShape + * \return The buffer shape after flatten + */ + static Array FlattenNewShape(const Array>& new_shape) { + Array ret; + ret.reserve(new_shape.size()); + for (int i = 0; i < static_cast(new_shape.size()); i++) { + PrimExpr prod = 1; + for (int j = 0; j < static_cast(new_shape[i].size()); j++) { + prod *= new_shape[i][j]; + } + ret.push_back(prod); + } + return ret; + } + /** + * \brief Rewrite the index given the shape after rank promotion + * \param indices The original indices + * \param new_shape The buffer shape after rank promotion + * \return The new indices + */ + static Array RewriteIndex(const Array& indices, + const Array>& new_shape) { + Array new_indices; + ICHECK_EQ(indices.size(), new_shape.size()); + for (int i = 0; i < static_cast(indices.size()); i++) { + PrimExpr index = indices[i]; + // The indices transformed from one original dimension + Array index_dim(new_shape[i].size(), 0); + for (int j = static_cast(new_shape[i].size()) - 1; j >= 0; j--) { + index_dim.Set(j, floormod(index, new_shape[i][j])); + index = floordiv(index, new_shape[i][j]); + } + for (int j = 0; j < static_cast(new_shape[i].size()); j++) { + new_indices.push_back(index_dim[j]); + } + } + return new_indices; + } + /*! + * \brief Rewrite the index after buffer flattening + * \param indices The original indices + * \param new_shape The shape before buffer flattening + * \return The indices after buffer flattening + */ + static Array RewriteBackIndex(const Array& indices, + const Array>& new_shape) { + Array new_indices; + int offset = 0; + for (int i = 0; i < static_cast(new_shape.size()); i++) { + PrimExpr index = 0; + for (int j = 0; j < static_cast(new_shape[i].size()); j++) { + index *= new_shape[i][j]; + index += indices[offset + j]; + } + new_indices.push_back(index); + offset += new_shape[i].size(); + } + return new_indices; + } + RankPromoter(const Buffer& src, const Buffer& dst, const Array>& new_shape, + const Array>& relaxed_new_shape, const Array& relaxed_region) + : src_(src), + dst_(dst), + new_shape_(new_shape), + relaxed_new_shape_(relaxed_new_shape), + relaxed_region_(relaxed_region) {} + + static Stmt RewriteBody(Stmt stmt, const Buffer& src, const Buffer& dst, + const Array>& new_shape, + const Array>& relaxed_new_shape, + const Array& relaxed_region) { + RankPromoter promoter(src, dst, new_shape, relaxed_new_shape, relaxed_region); + return promoter(stmt); + } + + private: + Stmt VisitStmt_(const BufferStoreNode* _store) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_store)); + if (store->buffer.same_as(src_)) { + ObjectPtr new_store = make_object(*store.get()); + new_store->buffer = dst_; + new_store->indices = ConvertIndices(new_store->indices); + return BufferStore(new_store); + } + return std::move(store); + } + + PrimExpr VisitExpr_(const BufferLoadNode* _load) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); + if (load->buffer.same_as(src_)) { + ObjectPtr new_load = make_object(*load.get()); + new_load->buffer = dst_; + new_load->indices = ConvertIndices(new_load->indices); + return BufferLoad(new_load); + } + return std::move(load); + } + + /*! + * \brief Rewrite the indices after performing buffer rank promotion + + * buffer compacting + buffer flattening. + * \param indices The original indices + * \return The indices after these transformations + */ + Array ConvertIndices(const Array& indices) { + Array rewrite_indices = RewriteIndex(indices, new_shape_); + arith::Analyzer analyzer; + for (int i = 0; i < static_cast(rewrite_indices.size()); i++) { + rewrite_indices.Set(i, analyzer.Simplify(rewrite_indices[i] - relaxed_region_[i]->min)); + } + return RewriteBackIndex(rewrite_indices, relaxed_new_shape_); + } + + const Buffer& src_; + const Buffer& dst_; + Array> new_shape_; + Array> relaxed_new_shape_; + Array relaxed_region_; +}; + +std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope, + Optional compute_location, + const Array& outer_loops, Buffer* alloc_buffer) { + Stmt body = stmt; + std::vector loops; + bool need_relax = !compute_location.defined(); + Map relax_var_range; + Map all_var_range; + PrimExpr vector_bytes = -1; + // Step 1. Perform rank promotion on the buffer access, turning a strided-changing dimension into + // several contiguous-changing dimensions + // Step 1.1 collect loop var range for rank promotion + while (const ForNode* loop = body.as()) { + if (need_relax) { + relax_var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } else { + loops.push_back(loop); + } + all_var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + if (loop == compute_location.get()) { + need_relax = true; + } + if (loop->kind == ForKind::kVectorized) { + vector_bytes = loop->extent; + } + body = loop->body; + } + for (const For& loop : outer_loops) { + if (loop->kind == ForKind::kThreadBinding) { + const String& thread_tag = loop->thread_binding.value()->thread_tag; + if (CanRelaxStorageUnderThread(runtime::StorageScope::Create(storage_scope), + runtime::ThreadScope::Create(thread_tag))) { + relax_var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + } + all_var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + + const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode); + // TODO: the assumption that the RHS of BufferStore is BufferLoad may not be accurate + const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_load, buf_store->value, BufferLoadNode); + Buffer orig_buffer = is_write_cache ? buf_store->buffer : buf_load->buffer; + Array indices = is_write_cache ? buf_store->indices : buf_load->indices; + // Step 1.2 get the new shape and new access indices after rank promotion + Array rewrite_indices; + Array> new_shape = + IndexPatternFinder::GetRankPromotedShape(indices, all_var_range, &rewrite_indices); + // Step 2. relax the access region after rank promotion + arith::Analyzer analyzer; + analyzer.Bind(all_var_range); + Array relaxed_region; + relaxed_region.reserve(rewrite_indices.size()); + { + Map relax_var_intset = AsIntSet(relax_var_range); + for (const PrimExpr& index : rewrite_indices) { + arith::IntSet int_set = arith::EvalSet(index, relax_var_intset); + relaxed_region.push_back(Range::FromMinExtent( + int_set.min(), analyzer.Simplify(int_set.max() - int_set.min() + 1))); + } + } + // Step 3. generate the data copy bodies + // preparation work + Array new_loop_vars; + Array orig_buf_indices, new_buf_indices; + Array> relaxed_new_shape; + for (int i = 0; i < static_cast(relaxed_region.size()); i++) { + Var new_loop_var = Var("ax" + std::to_string(i)); + new_loop_vars.push_back(new_loop_var); + orig_buf_indices.push_back(relaxed_region[i]->min + new_loop_var); + new_buf_indices.push_back(new_loop_var); + } + relaxed_new_shape.reserve(new_shape.size()); + for (int i = 0, ct = 0; i < static_cast(new_shape.size()); i++) { + Array layer; + for (int j = 0; j < static_cast(new_shape[i].size()); j++, ct++) { + layer.push_back(relaxed_region[ct]->extent); + } + relaxed_new_shape.push_back(layer); + } + // Step 3.1 create a buffer for the cache + Buffer new_buffer = WithScope(orig_buffer, storage_scope); + new_buffer.CopyOnWrite()->shape = RankPromoter::FlattenNewShape(relaxed_new_shape); + *alloc_buffer = new_buffer; + Array real_orig_buf_indices = + RankPromoter::RewriteBackIndex(orig_buf_indices, new_shape); + Array real_new_buf_indices = + RankPromoter::RewriteBackIndex(new_buf_indices, relaxed_new_shape); + // Step 3.2 generate a body that writes to the cache + Stmt generate_body = is_write_cache + ? BufferStore(orig_buffer, BufferLoad(new_buffer, real_new_buf_indices), + real_orig_buf_indices) + : BufferStore(new_buffer, BufferLoad(orig_buffer, real_orig_buf_indices), + real_new_buf_indices); + for (int i = static_cast(relaxed_region.size()) - 1; i >= 0; i--) { + if (i == static_cast(relaxed_region.size()) - 1 && !is_const_int(vector_bytes, -1)) { + ICHECK(analyzer.CanProve(vector_bytes == relaxed_region[i]->extent)); + generate_body = + For(new_loop_vars[i], 0, relaxed_region[i]->extent, ForKind::kVectorized, generate_body); + } else { + generate_body = + For(new_loop_vars[i], 0, relaxed_region[i]->extent, ForKind::kSerial, generate_body); + } + } + // Step 3.3 rewrite the original body to load from cache + Stmt rewrite_body; + if (compute_location.defined()) { + rewrite_body = compute_location.value()->body; + } else { + rewrite_body = stmt; + } + rewrite_body = RankPromoter::RewriteBody(rewrite_body, orig_buffer, new_buffer, new_shape, + relaxed_new_shape, relaxed_region); + SeqStmt insert_location; + if (is_write_cache) { + generate_body = insert_location = SeqStmt({rewrite_body, generate_body}); + } else { + generate_body = insert_location = SeqStmt({generate_body, rewrite_body}); + } + generate_body = CopyLoopChain(loops, generate_body); + return std::make_pair(generate_body, insert_location); +} + +Stmt CreateLocalStage::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt body; + Optional compute_location; + std::tie(body, compute_location) = LiftThreadBindingLoops(std::move(stmt)); + Buffer cache_buffer; + Stmt after_caching = InsertCacheStage(body, false, "local", compute_location, + constraints.outer_loops, &cache_buffer) + .first; + output->alloc_buffer.push_back(cache_buffer); + return after_caching; +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/memhammer_lower_auto_copy.cc b/src/tir/transforms/memhammer_lower_auto_copy.cc new file mode 100644 index 000000000000..a0103aab380b --- /dev/null +++ b/src/tir/transforms/memhammer_lower_auto_copy.cc @@ -0,0 +1,763 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "../../runtime/thread_storage_scope.h" +#include "../schedule/utils.h" +#include "./ir_utils.h" +#include "./memhammer_rewrite_rule.h" + +namespace tvm { +namespace tir { + +using support::NDIntSet; + +// rewrite rules +static InverseMapping inverse_mapping; +static CoalescedAccess coalesced_access; +static CreateLocalStage create_local_stage; +static SharedToWmma shared_to_wmma; +static WmmaToGlobal wmma_to_global; +static WmmaToShared wmma_to_shared; + +/*! + * \brief A class to perform auto padding. + * + * One simple way to perform auto padding is to fix each padding size for each dimension at the + * same time, calculate the precise access index and the bank conflict, + * and choose the one with minimal conflict. However, this algorithm has exponential complexity. + * Suppose we have d dimensions and the padding size is 0-31, we need to calculate bank + * conflict for 32^{d-1} times. + * We propose a fast incremental algorithm that works for affine inputs, and it only calculate + * bank conflict for 32*{d-1} times. To be specific, we first decide the optimal padding size for + * dimension d-2, then for dimension d-3, ..., finally for dimension 0. It involves 2 steps. + * + * First, we analyze how a typical warp accesses the shared memory banks. + * A typical warp means setting all irrelevant loop vars to 0, and only keeps the threads in a warp. + * For each dimension, the access index is represented by + * x_1 * scale_1 + ... + x_n * scale_n (x_i is loop var) + * Note: The affine property guarantees that {x_i} must be independent, + * otherwise the algorithm is wrong. + * We will use this information to keep a list for each dimension called "iteration space" that + * records the resulting index as x_i takes each possible value. + * + * For example, the index is [outer*2+ty, tx*4+vec], where tx is threadIdx.x, and ty is threadIdx.y. + * tx is in [0, 16), and ty is in [0, 2). + * We will first get a warp access [ty, tx*4] because outer and vec are irrelevant loop vars. + * It's obvious that ty, tx*4 are both in the form of x_1 * scale_1 + ... + x_n * scale_n. + * In this case, we will keep lists {{0, 1}, {0, 4, ..., 60}} + * + * Next, we choose a padding size that has minimal conflict from the last dimension to first one. + * To calculate the conflict, we calculate the Cartesian product of the iteration space of all + * dimensions not higher than this. Each single point of product space represents access index + * of a particular thread, by which we can calculate the accessed memory bank. The conflict is + * the highest access frequency among the banks. + * + */ +class AutoPadder { + public: + /** + * \brief Do padding to the given buffers in shard memory + * \param buffers the given buffers + * \return the list of new padded buffers + */ + Array PadSharedMemory(const Array& buffers) { + Array result; + + for (const Buffer& buffer : buffers) { + runtime::StorageScope scope = runtime::StorageScope::Create(buffer.scope()); + if (scope.rank == runtime::StorageRank::kShared) { + auto iter_spaces = iter_spaces_[buffer.get()]; + if (iter_spaces.empty()) { + result.push_back(buffer); + continue; + } + // The access index represented by points in the cartesian product of lower dimension + // iteration spaces + std::vector> low_dim_iter_space(iter_spaces.size(), std::vector()); + + int n = buffer->shape.size(); + int data_bits = buffer->dtype.bits(); + // Step 1. initialize `low_dim_iter_space` with the iteration space of the last dim + for (int i = 0; i < static_cast(iter_spaces.size()); i++) { + auto last_dim_iter_space = iter_spaces[i][n - 1]; + low_dim_iter_space[i] = last_dim_iter_space; + } + PrimExpr stride = 1; + Array reverse_strides; + int pad_min = padding_min_.Get(buffer).value_or(Integer(1)); + // Step 2. For each dimension, select a padding that has minimal bank conflict + for (int k = n - 2; k >= 0; k--) { // dims + int max_pad_size = std::min( + int(max_pad_factor_ * (stride * buffer->shape[k + 1]).as()->value), + 32 * 32 / data_bits); + int min_conflict = INT32_MAX; + int min_conflict_pad = -1; + for (int pad = 0; pad <= max_pad_size; pad += pad_min) { // select padding + int padded_stride = ((stride * buffer->shape[k + 1]).as()->value + pad) % + (32 * 32 / data_bits); + int conflict = 0; + for (int i = 0; i < static_cast(iter_spaces.size()); i++) { // accesses + auto iter_space = iter_spaces[i][k]; + int bank[32]{0}; + for (int v1 : iter_space) { + for (int v2 : low_dim_iter_space[i]) { + int comb = (v1 * padded_stride + v2) * data_bits / 32 % 32; + bank[comb]++; + } + } + for (int j = 0; j < 32; j++) { + conflict = std::max(conflict, bank[j]); + } + } + if (conflict < min_conflict) { + min_conflict = conflict; + min_conflict_pad = pad; + } + } + // update low_dim_iter_space with + for (int i = 0; i < static_cast(iter_spaces.size()); i++) { // accesses + auto iter_space = iter_spaces[i][k]; + if (!iter_space.empty()) { + int padded_stride = + ((stride * buffer->shape[k + 1]).as()->value + min_conflict_pad) % + (32 * 32 / data_bits); + std::vector span; + for (int v1 : iter_space) { + for (int v2 : low_dim_iter_space[i]) { + span.push_back(((v1 * padded_stride + v2) * data_bits) % (32 * 32 / data_bits)); + } + } + low_dim_iter_space[i] = span; + } else { + ICHECK(min_conflict_pad == 0); + } + } + stride = stride * buffer->shape[k + 1] + min_conflict_pad; + reverse_strides.push_back(stride); + } + // Step 3. create the new padded buffer + ObjectPtr b = make_object(*buffer.get()); + Array strides; + for (int i = static_cast(reverse_strides.size()) - 1; i >= 0; i--) { + strides.push_back(reverse_strides[i]); + } + strides.push_back(1); + b->strides = strides; + Buffer new_buffer(b); + result.push_back(new_buffer); + padded_buffer_map_.Set(buffer, new_buffer); + } else { + result.push_back(buffer); + } + } + return result; + } + + /** + * \brief Replace all occurrence of the old buffer with the new buffer in the stmt + * \param stmt the stmt to do replacement + * \return the stmt after replacement + */ + Stmt RewriteBufferAccess(const Stmt& stmt) { + class Rewriter : public StmtExprMutator { + public: + Rewriter(const Map& buffer_map) : buffer_map_(buffer_map) {} + + private: + PrimExpr VisitExpr_(const BufferLoadNode* _op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_op)); + BufferLoadNode* op = load.CopyOnWrite(); + if (buffer_map_.count(op->buffer)) { + op->buffer = buffer_map_[op->buffer]; + } + return std::move(load); + } + + Stmt VisitStmt_(const BufferStoreNode* _op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_op)); + BufferStoreNode* op = store.CopyOnWrite(); + if (buffer_map_.count(op->buffer)) { + op->buffer = buffer_map_[op->buffer]; + } + return std::move(store); + } + + Stmt VisitStmt_(const BlockNode* op) final { + // To reduce the number of blocks in block sref reuse map, we check whether the block is + // really mutated (i.e., the old buffer appears in the block). If so, we return the block + // after mutation. Otherwise we just return the original block. + bool changed = false; + // Step 1. Mutate the read region. + Array reads; + for (const BufferRegion& read : op->reads) { + if (buffer_map_.count(read->buffer)) { + changed = true; + reads.push_back(BufferRegion(buffer_map_[read->buffer], read->region)); + } else { + reads.push_back(read); + } + } + // Step 2. Mutate the write region. + Array writes; + for (const BufferRegion& write : op->writes) { + if (buffer_map_.count(write->buffer)) { + changed = true; + writes.push_back(BufferRegion(buffer_map_[write->buffer], write->region)); + } else { + writes.push_back(write); + } + } + // Step 4. Mutate `match_buffers`. If an old buffer appears as a source of + // MatchBufferRegion, the storage scope of the target buffer also needs to be set. + Array match_buffers; + for (const MatchBufferRegion& match_buffer : op->match_buffers) { + if (buffer_map_.count(match_buffer->source->buffer)) { + changed = true; + Buffer new_buffer = buffer_map_[match_buffer->source->buffer]; + match_buffers.push_back(MatchBufferRegion( + match_buffer->buffer, BufferRegion(new_buffer, match_buffer->source->region))); + } else { + match_buffers.push_back(match_buffer); + } + } + // Step 5. Recursively mutate the block. + Stmt res = StmtMutator::VisitStmt_(op); + if (res.get() != op) { + changed = true; + } + + if (changed) { + ObjectPtr block = CopyOnWrite(res.as()); + block->reads = std::move(reads); + block->writes = std::move(writes); + block->match_buffers = std::move(match_buffers); + return Stmt(block); + } else { + return GetRef(op); + } + } + const Map& buffer_map_; + }; + Rewriter rewriter(padded_buffer_map_); + return rewriter(stmt); + } + + /** + * \brief an equivalent of scale * loop_var with loop_var: {min=0, extent=extent} + */ + struct Pattern { + int extent; + int scale; + }; + + /** + * \brief Collect pattern from indices + */ + class PatternCollector : public StmtExprVisitor { + void VisitExpr_(const VarNode* op) final { + if (!success_) { + return; + } + int extent = var_range_[GetRef(op)]->extent.as()->value; + if (extent > 1) { + stack_.push({{extent, 1}}); + } else { + stack_.push({}); + } + } + + void VisitExpr_(const AddNode* op) final { + ExprVisitor::VisitExpr_(op); + if (!success_) { + return; + } + std::vector merged_patterns; + std::vector r = stack_.top(); + stack_.pop(); + std::vector l = stack_.top(); + stack_.pop(); + for (const Pattern& pattern : l) { + merged_patterns.push_back(pattern); + } + for (const Pattern& pattern : r) { + merged_patterns.push_back(pattern); + } + if (merged_patterns.empty()) { + stack_.push({}); + return; + } + std::vector ret; + ret.push_back(merged_patterns[0]); + for (int i = 0; i < static_cast(merged_patterns.size()); i++) { + Pattern prev_pattern = ret.back(); + if (merged_patterns[i].extent * merged_patterns[i].scale == prev_pattern.scale) { + ret.pop_back(); + ret.push_back( + {prev_pattern.extent * merged_patterns[i].extent, merged_patterns[i].scale}); + } + } + stack_.push(ret); + } + + void VisitExpr_(const FloorDivNode* op) final { + ExprVisitor::VisitExpr_(op); + if (!success_) { + return; + } + std::vector inner = stack_.top(); + stack_.pop(); + int lower_factor = op->b.as()->value; + std::vector ret; + for (const Pattern& pattern : inner) { + if (pattern.scale >= lower_factor) { + if (pattern.scale % lower_factor == 0) { + ret.push_back({pattern.extent, pattern.scale / lower_factor}); + } else { + success_ = false; + } + } else if (pattern.scale * pattern.extent > lower_factor) { + if ((pattern.scale * pattern.extent) % lower_factor == 0) { + ret.push_back({pattern.extent * pattern.scale / lower_factor, 1}); + } else { + success_ = false; + } + } + } + stack_.push(ret); + } + + void VisitExpr_(const FloorModNode* op) final { + ExprVisitor::VisitExpr_(op); + if (!success_) { + return; + } + std::vector inner = stack_.top(); + stack_.pop(); + int extent = op->b.as()->value; + std::vector ret; + for (const Pattern& pattern : inner) { + if (pattern.scale < extent) { + if (extent % pattern.scale == 0) { + if (extent / pattern.scale < pattern.extent) { + ret.push_back({extent / pattern.scale, pattern.scale}); + } else { + ret.push_back({pattern.extent, pattern.scale}); + } + } else { + success_ = false; + } + } + } + stack_.push(ret); + } + + void VisitExpr_(const MulNode* op) final { + ExprVisitor::VisitExpr_(op); + if (!success_) { + return; + } + std::vector inner = stack_.top(); + stack_.pop(); + int scale = op->b.as()->value; + std::vector ret; + for (const Pattern& pattern : inner) { + ret.push_back({pattern.extent, pattern.scale * scale}); + } + stack_.push(ret); + } + + public: + PatternCollector(const Map& var_range) : var_range_(var_range) {} + + /*! + * \brief Collect the iteration space for given indices. The iteration space is the possible + * values that an index can take (do not remove duplicate). + * For example, the input is [ty, tx*4], where tx is in [0, 16), and ty is in [0, 2). + * The output would be {{0, 1}, {0, 4, ..., 60}} + * \param indices The indices to analyze + * \param var_range The range of loop variables + * \param data_bits The size of dtype in bits + * \return The iteration space. The first array represents dimensions, and the second array + * represents the iteration space of one dimension + */ + static std::vector> CollectIterationSpace(const Array& indices, + const Map& var_range, + int data_bits) { + PatternCollector collector(var_range); + std::vector> ret; + for (int i = 0; i < static_cast(indices.size()); i++) { + collector(indices[i]); + if (collector.success_ && collector.stack_.size() == 1) { + auto patterns = collector.stack_.top(); + int extent_prod = 1; + for (const Pattern& p : patterns) { + extent_prod *= p.extent; + } + std::vector iter_space; + for (int thread_id = 0; thread_id < extent_prod; thread_id++) { + int index = 0; + int n = thread_id; + for (int j = static_cast(patterns.size()) - 1; j >= 0; j--) { + int val = n % patterns[j].extent; + index += val * patterns[j].scale; + n /= patterns[j].extent; + } + iter_space.push_back(index); + } + + ret.push_back(iter_space); + collector.stack_.pop(); + } else { + ret.push_back({}); + } + } + return ret; + } + + std::stack> stack_; + const Map& var_range_; + bool success_ = true; + }; + + /*! A utility class for calling CollectIterationSpace to each buffer access*/ + class IterSpaceAnalyzer : public StmtExprVisitor { + public: + IterSpaceAnalyzer(const Map& substitute_map, AutoPadder* self, int data_bits, + const Map warp_thread_extent) + : substitute_map_(substitute_map), + self(self), + data_bits_(data_bits), + warp_thread_extent_(warp_thread_extent) {} + + private: + bool CheckVarContiguous(PrimExpr e, Var var) { + PrimExpr e1 = Substitute(e, [var](const Var& v) -> Optional { + if (v.same_as(var)) { + return Integer(0); + } else { + return v; + } + }); + PrimExpr e2 = Substitute(e, [var](const Var& v) -> Optional { + if (v.same_as(var)) { + return Integer(1); + } else { + return v; + } + }); + arith::Analyzer analyzer; + return analyzer.CanProve(e2 - e1 == 1); + } + + void VisitStmt_(const ForNode* op) final { + if (op->kind != ForKind::kThreadBinding) { + substitute_map_.Set(op->loop_var, op->min); + } else { + Integer extent = + warp_thread_extent_.Get(op->thread_binding.value()->thread_tag).value_or(1); + var_range_.Set(op->loop_var, Range::FromMinExtent(op->min, extent)); + } + if (op->kind == ForKind::kVectorized) { + vector_var = op->loop_var; + vector_length_ = op->extent.as()->value; + } + StmtExprVisitor::VisitStmt_(op); + if (op->kind == ForKind::kVectorized) { + vector_length_ = -1; + } + if (op->kind != ForKind::kThreadBinding) { + substitute_map_.erase(op->loop_var); + } + } + /*! + * \brief Take a typical warp and collect the iteration space for buffer store + * For example, the access is A[outer*2+ty, tx*4+vec] = xxx, where tx is threadIdx.x, and ty is + * threadIdx.y. tx is in [0, 16), and ty is in [0, 2). + * The iteration space would be {{0, 1}, {0, 4, ..., 60}}. + * \param op the buffer store + */ + void VisitStmt_(const BufferStoreNode* op) final { + runtime::StorageScope scope = runtime::StorageScope::Create(op->buffer.scope()); + if (scope.rank == runtime::StorageRank::kShared) { + Array substitued_indices; + arith::Analyzer analyzer; + for (const PrimExpr& e : op->indices) { + substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); + } + std::vector> iter_space = + PatternCollector::CollectIterationSpace(substitued_indices, var_range_, data_bits_); + if (!iter_space.empty()) { + self->iter_spaces_[op->buffer.get()].push_back(iter_space); + } + if (vector_length_ != -1 && CheckVarContiguous(substitued_indices.back(), vector_var)) { + Integer m = self->padding_min_.Get(op->buffer).value_or(1); + self->padding_min_.Set(op->buffer, Downcast(max(vector_length_, m))); + } + } + StmtExprVisitor::VisitStmt_(op); + } + /*! + * \brief Take a typical warp and collect the iteration space for buffer load + * For example, the access is xxx = A[outer*2+ty, tx*4+vec], where tx is threadIdx.x, and ty is + * threadIdx.y. tx is in [0, 16), and ty is in [0, 2). + * The iteration space would be {{0, 1}, {0, 4, ..., 60}}. + * \param op the buffer load + */ + void VisitExpr_(const BufferLoadNode* op) final { + runtime::StorageScope scope = runtime::StorageScope::Create(op->buffer.scope()); + if (scope.rank == runtime::StorageRank::kShared) { + Array substitued_indices; + arith::Analyzer analyzer; + for (const PrimExpr& e : op->indices) { + substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); + } + std::vector> iter_space = + PatternCollector::CollectIterationSpace(substitued_indices, var_range_, data_bits_); + if (!iter_space.empty()) { + self->iter_spaces_[op->buffer.get()].push_back(iter_space); + } + if (vector_length_ != -1 && CheckVarContiguous(substitued_indices.back(), vector_var)) { + Integer m = self->padding_min_.Get(op->buffer).value_or(1); + self->padding_min_.Set(op->buffer, Downcast(max(vector_length_, m))); + } + } + StmtExprVisitor::VisitExpr_(op); + } + + /*! + * \brief Take a typical warp and collect the iteration space for load_matrix_sync and + * store_matrix_sync + * For example, the access region is A[y*16+16, x*16+16], where y and x are not bound to + * threadIdx. The iteration space would be {{0, 1, ..., 15}, {0, 1, ..., 15}}. + * \param op the call node + */ + void VisitStmt_(const BlockNode* op) final { + if (const auto* eval = op->body.as()) { + if (const auto* call = eval->value.as()) { + if (call->op == builtin::tvm_load_matrix_sync() || + call->op == builtin::tvm_store_matrix_sync()) { + for (const MatchBufferRegion& r : op->match_buffers) { + Buffer src_buffer = r->source->buffer; + runtime::StorageScope scope = runtime::StorageScope::Create(src_buffer.scope()); + if (scope.rank == runtime::StorageRank::kShared) { + Region region = r->source->region; + Array indices; + for (int i = 0; i < static_cast(region.size()); i++) { + Var var("region" + std::to_string(i)); + indices.push_back(region[i]->min + var); + var_range_.Set(var, Range::FromMinExtent(0, region[i]->extent)); + } + Array substitued_indices; + arith::Analyzer analyzer; + for (const PrimExpr& e : indices) { + substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); + } + std::vector> iter_space = PatternCollector::CollectIterationSpace( + substitued_indices, var_range_, data_bits_); + if (!iter_space.empty()) { + self->iter_spaces_[src_buffer.get()].push_back(iter_space); + } + } + } + } + } + } + } + + Map substitute_map_; + AutoPadder* self; + int data_bits_; + Map warp_thread_extent_; + Map var_range_; + int vector_length_ = -1; + Var vector_var; + }; + + /*! + * \brief Analyze the shared memory access + * \param stmt The data copy + * \param outer_loops The outer loops of the stmt + * \param data_bits The length of dtype in bits + * \param thread_extent The extents of all thread binding loops + */ + void AnalyzeSharedMemoryAccess(const Stmt& stmt, const Array& outer_loops, int data_bits, + const Map& thread_extent) { + Map warp_thread_extent; + Integer prod = 1; + Array thread_tags{"threadIdx.x", "threadIdx.y", "threadIdx.z"}; + arith::Analyzer analyzer; + for (int i = 0; i < 3; i++) { + Integer extent = thread_extent.Get(thread_tags[i]).value_or(1); + if (analyzer.CanProve(prod * extent >= 32)) { + warp_thread_extent.Set(thread_tags[i], Downcast(floordiv(32, prod))); + prod *= floordiv(32, prod); + break; + } else { + warp_thread_extent.Set(thread_tags[i], Downcast(extent)); + prod *= extent; + } + } + Map substitute_map; + for (const For& loop : outer_loops) { + substitute_map.Set(loop->loop_var, loop->min); + } + IterSpaceAnalyzer iter_space_analyzer(substitute_map, this, data_bits, warp_thread_extent); + iter_space_analyzer(stmt); + } + + private: + /*! \brief A map from the old buffers to the new padded buffers */ + Map padded_buffer_map_; + /*! \brief A map from each buffer to the iteration spaces of the accesses*/ + std::unordered_map>>> iter_spaces_; + /*! \brief A map from each buffer to their minimal padding size */ + Map padding_min_; + /*! \brief max padding size in relative to the original shape*/ + const double max_pad_factor_ = 0.25; + + friend class AutoCopyMutator; +}; + +class AutoCopyMutator : public StmtExprMutator { + public: + explicit AutoCopyMutator(Map thread_extent) : thread_extent_(thread_extent) {} + /** + * \brief Replace old buffers with padded buffers in the stmt + * \param stmt The stmt to rewrite + * \return The stmt after rewrite + */ + Stmt RewritePaddingBody(const Stmt& stmt) { return padder.RewriteBufferAccess(stmt); } + + private: + Stmt VisitStmt_(const BlockNode* op) final { + Block block = Downcast(StmtMutator::VisitStmt_(op)); + // only rewrite the block annotated with "auto_copy" + if (GetAnn(op, "auto_copy").value_or(0)->value == 0) { + BlockNode* n = block.CopyOnWrite(); + n->alloc_buffers = padder.PadSharedMemory(std::move(n->alloc_buffers)); + return std::move(block); + } + ICHECK_EQ(block->reads.size(), 1); + ICHECK_EQ(block->writes.size(), 1); + int data_bits = block->reads[0]->buffer->dtype.bits(); + ConstraintSet constraints(this->thread_extent_, // + this->outer_loops_, // + block->reads[0], // + block->writes[0], // + data_bits, // + block->annotations); + BlockNode* n = block.CopyOnWrite(); + OutputSet outputs; + for (RewriteRule* rule : rules) { + n->body = rule->Apply(std::move(n->body), constraints, &outputs); + } + for (const Buffer& buffer : outputs.alloc_buffer) { + n->alloc_buffers.push_back(buffer); + } + for (const auto& p : outputs.padding_min) { + Integer m = padder.padding_min_.Get(p.first).value_or(1); + padder.padding_min_.Set(p.first, Downcast(max(p.second, m))); + } + padder.AnalyzeSharedMemoryAccess(block->body, outer_loops_, data_bits, thread_extent_); + n->alloc_buffers = padder.PadSharedMemory(std::move(n->alloc_buffers)); + return std::move(block); + } + + Stmt VisitStmt_(const ForNode* op) final { + outer_loops_.push_back(GetRef(op)); + Stmt stmt = StmtMutator::VisitStmt_(op); + outer_loops_.pop_back(); + return stmt; + } + + /*! \brief Thread extents collected. */ + Map thread_extent_; + /*! \brief The outer loops during recursive visit */ + Array outer_loops_; + /*! \brief Calculating optimal padding size */ + AutoPadder padder; + + /*! \brief All rewrite rules. */ + const std::array rules = { + &inverse_mapping, // + &coalesced_access, // + &create_local_stage, // + &shared_to_wmma, // + &wmma_to_global, // + &wmma_to_shared, + }; +}; + +/*! + * \brief Collect the extent for all thread binding loops. + */ +class ThreadExtentCollector : public StmtVisitor { + public: + static Map CollectThreadExtent(const Stmt& stmt) { + ThreadExtentCollector collector; + collector(stmt); + return collector.thread_extent_; + } + + private: + void VisitStmt_(const BlockNode* op) final { + if (Optional warp_execution = GetAnn(op, "warp_execution")) { + if (warp_execution.value()->value != 0) { + thread_extent_.Set("threadIdx.x", Integer(32)); + } + } + StmtVisitor::VisitStmt_(op); + } + void VisitStmt_(const ForNode* op) final { + if (op->thread_binding.defined() && op->thread_binding.value()->iter_type == kThreadIndex) { + thread_extent_.Set(op->thread_binding.value()->thread_tag, Downcast(op->extent)); + } + StmtVisitor::VisitStmt_(op); + } + + /*! \brief the map from thread tag to its extent */ + Map thread_extent_; +}; + +namespace transform { + +Pass LowerAutoCopy() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + AutoCopyMutator mutator(ThreadExtentCollector::CollectThreadExtent(n->body)); + n->body = mutator(std::move(n->body)); + n->body = mutator.RewritePaddingBody(n->body); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerAutoCopy", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerAutoCopy").set_body_typed(LowerAutoCopy); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/memhammer_rewrite_rule.h b/src/tir/transforms/memhammer_rewrite_rule.h new file mode 100644 index 000000000000..1cb0ea496a03 --- /dev/null +++ b/src/tir/transforms/memhammer_rewrite_rule.h @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include + +#include "../schedule/utils.h" + +namespace tvm { +namespace tir { + +/*! \brief The set containing all possible constraints of a data copy */ +struct ConstraintSet { + /*! \brief The extents of the thread binding loops */ + Map thread_extent; + /*! \brief The outer loops surrounding the data copy */ + Array outer_loops; + /*! \brief The read region of the data copy */ + BufferRegion read_region; + /*! \brief The write region of the data copy */ + BufferRegion write_region; + /*! \brief The dtype size in bits */ + int data_bits; + /*! \brief Whether to insert a local stage in the data copy */ + int add_local_stage = 0; + /*! \brief The vectorization length in bytes */ + int vector_bytes = 1; + + explicit ConstraintSet(Map thread_extent, // + Array outer_loops, // + BufferRegion read_region, // + BufferRegion write_region, // + int data_bits, // + const Map& ann) + : thread_extent(thread_extent), + outer_loops(outer_loops), + read_region(read_region), + write_region(write_region), + data_bits(data_bits) { + if (Optional add_local_stage = ann.Get("local_stage")) { + this->add_local_stage = Downcast(add_local_stage.value())->value; + } + if (Optional vector_bytes = ann.Get("vector_bytes")) { + this->vector_bytes = Downcast(vector_bytes.value())->value; + } + } +}; + +/*! \brief The set containing all possible outputs of a rewrite rule */ +struct OutputSet { + /*! \brief New buffers allocated after rewrite */ + Array alloc_buffer; + /*! \brief The minimal padding size of a buffer in base 2 logarithm */ + Map padding_min; +}; + +/*! + * \brief Rules to rewrite a data copy. + */ +class RewriteRule { + protected: + /* RewriteRule() = default; */ + /*! + * \brief Rewrite the stmt under certain constraints + * \param stmt The stmt + * \param constraints The constraints of the rewrite + * \param output Some additional information that the rewrite rule produces. (including the new + * buffer to be allocated, etc.) + * \return the stmt after rewrite + */ + virtual Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const = 0; + /*! + * \brief Whether the rewrite rule can be applied to the stmt under certain constraints + * \param stmt The stmt + * \param constraints The constraints of the rewrite + * \return A boolean flag indicating whether the rule can be applied + */ + virtual bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const { return true; } + + public: + inline Stmt Apply(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const { + if (CanApply(stmt, constraints)) { + return Rewrite(stmt, constraints, output); + } else { + return stmt; + } + } +}; + +inline bool IsCopyBetweenScope(const Buffer& src_buffer, const Buffer& tgt_buffer, + runtime::StorageRank src_rank, runtime::StorageRank tgt_rank) { + runtime::StorageScope src_scope = runtime::StorageScope::Create(src_buffer.scope()); + runtime::StorageScope tgt_scope = runtime::StorageScope::Create(tgt_buffer.scope()); + return src_scope.rank == src_rank && tgt_scope.rank == tgt_rank; +} + +/*! + * \brief Coalesce and vectorize memory access. + */ +class CoalescedAccess : public RewriteRule { + public: + CoalescedAccess() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kGlobal, + runtime::StorageRank::kShared) || + IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kGlobal); + } +}; + +/*! + * \brief Transform from A[f(i,j)] = B[i,j] to A[i,j] = B[f^{-1}(i,j)] + */ +class InverseMapping : public RewriteRule { + public: + InverseMapping() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kGlobal); + } +}; + +/*! + * \brief Create a local stage when loading from global memory to shared memory. + */ +class CreateLocalStage : public RewriteRule { + public: + CreateLocalStage() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kGlobal, + runtime::StorageRank::kShared) && + is_one(constraints.add_local_stage); + } +}; + +/*! + * \brief Add a cache stage in shared memory. Perform tensor core rewrite for wmma->shared, and + * perform coalescing and vectorizing for shared->global. + */ +class WmmaToGlobal : public RewriteRule { + public: + WmmaToGlobal() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kWMMAAccumulator, + runtime::StorageRank::kGlobal); + } +}; + +/*! + * \brief Rewrite shared->wmma data copy with load_matrix_sync + */ +class SharedToWmma : public RewriteRule { + public: + SharedToWmma() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kWMMAMatrixA) || + IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kShared, + runtime::StorageRank::kWMMAMatrixB); + } +}; + +/*! + * \brief Rewrite wmma->shared data copy with store_matrix_sync + */ +class WmmaToShared : public RewriteRule { + public: + WmmaToShared() = default; + Stmt Rewrite(const Stmt& stmt, const ConstraintSet& constraints, OutputSet* output) const final; + bool CanApply(const Stmt& stmt, const ConstraintSet& constraints) const final { + Buffer src_buffer = constraints.read_region->buffer; + Buffer tgt_buffer = constraints.write_region->buffer; + return IsCopyBetweenScope(src_buffer, tgt_buffer, runtime::StorageRank::kWMMAAccumulator, + runtime::StorageRank::kShared); + } +}; + +/*! + * \brief Insert a cache stage to the compute location + * \param stmt the stmt + * \param is_write_cache whether to write a read cache or write cache + * \param storage_scope the storage scope of the new cache + * \param compute_location the compute location. + * \param outer_loops the outer loops of this stmt + * \param alloc_buffer the new cache block + * \return a pair. The first is the stmt after transformation. + * The second is the SeqStmt that contains 2 stages (one original and another inserted). + */ +std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope, + Optional compute_location, + const Array& outer_loops, Buffer* alloc_buffer); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/memhammer_tensorcore_rewrite.cc b/src/tir/transforms/memhammer_tensorcore_rewrite.cc new file mode 100644 index 000000000000..6e880146d618 --- /dev/null +++ b/src/tir/transforms/memhammer_tensorcore_rewrite.cc @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./memhammer_rewrite_rule.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Tile the 2 innermost loops to extent=16. This helps further tensor core rewrite. + * \param stmt The stmt + * \return A pair. The first is the stmt after transformation. + * The second is the compute location where we may add write cache. + */ +std::pair> TileWmmaBlock(Stmt stmt) { + Stmt body = stmt; + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + int n = loops.size(); + PrimExpr extent_last1 = loops[n - 1]->extent; + PrimExpr extent_last2 = loops[n - 2]->extent; + { + arith::Analyzer analyzer; + if (!analyzer.CanProveEqual(floormod(extent_last1, 16), 0) || + !analyzer.CanProveEqual(floormod(extent_last2, 16), 0)) { + return std::make_pair(stmt, NullOpt); + } + } + Var new_loop_vars[4] = { + /*0:*/ loops[n - 2]->loop_var.copy_with_suffix("_0"), + /*1:*/ loops[n - 1]->loop_var.copy_with_suffix("_0"), + /*2:*/ loops[n - 2]->loop_var.copy_with_suffix("_1"), + /*3:*/ loops[n - 1]->loop_var.copy_with_suffix("_1"), + }; + body = Substitute(std::move(body), + Map{ + {loops[n - 2]->loop_var, new_loop_vars[0] * 16 + new_loop_vars[2]}, + {loops[n - 1]->loop_var, new_loop_vars[1] * 16 + new_loop_vars[3]}, + }); + { + PrimExpr factor[4] = { + /*0:*/ floordiv(extent_last2, 16), // + /*1:*/ floordiv(extent_last1, 16), // + /*3:*/ 16, // + /*4:*/ 16, // + }; + body = For(new_loop_vars[3], 0, factor[3], ForKind::kSerial, std::move(body)); + body = For(new_loop_vars[2], 0, factor[2], ForKind::kSerial, std::move(body)); + body = For(new_loop_vars[1], 0, factor[1], ForKind::kSerial, std::move(body)); + body = For(new_loop_vars[0], 0, factor[0], ForKind::kSerial, std::move(body)); + } + For compute_location = Downcast(body); + for (int i = n - 3; i >= 0; i--) { + body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, std::move(body), + loops[i]->thread_binding, loops[i]->annotations); + } + return {body, compute_location}; +} + +Array RelaxIndices(const Array& indices, const Array& shape, + const Map& var_dom) { + Array int_set = arith::EvalSet(indices, var_dom); + int ndim = int_set.size(); + Array region; + region.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + region.push_back(int_set[i].CoverRange(Range::FromMinExtent(0, shape[i]))); + }; + return region; +} + +/*! + * \brief Rewrite the data copy that stores to wmma fragment with wmma::load_matrix_sync + * \param stmt The stmt to rewrite + * \return The stmt after rewrite + */ +Stmt RewriteWmmaLoad(Stmt stmt) { + using arith::IntSet; + const DataType dtype = DataType::Float(16); + const DataType int32 = DataType::Int(32); + + Stmt body = stmt; + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + int n = loops.size(); + + Map var_dom{ + {loops[n - 1]->loop_var, IntSet::FromMinExtent(loops[n - 1]->min, loops[n - 1]->extent)}, + {loops[n - 2]->loop_var, IntSet::FromMinExtent(loops[n - 2]->min, loops[n - 2]->extent)}, + }; + // TODO: the assumption that the RHS of BufferStore is BufferLoad may not be accurate + const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode); + const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_load, buf_store->value, BufferLoadNode); + Buffer src_buffer = buf_load->buffer; + Buffer tgt_buffer = buf_store->buffer; + + Buffer new_src_buffer( + /*data=*/Var("src", PointerType(PrimType(dtype), src_buffer.scope())), + /*dtype=*/dtype, + /*shape=*/{Integer(16), Integer(16)}, + /*strides=*/{Var("s1", int32), Var("s0", int32)}, + /*elem_offset=*/Var("src_elem_offset", int32), + /*name=*/"src", + /*data_alignment=*/128, + /*offset_factor=*/16, + /*buffer_type=*/kDefault); + Buffer new_tgt_buffer( + /*data=*/Var("tgt", PointerType(PrimType(dtype), tgt_buffer.scope())), + /*dtype=*/dtype, + /*shape=*/{Integer(16), Integer(16)}, + /*strides=*/{}, + /*elem_offset=*/Var("tgt_elem_offset", int32), + /*name=*/"tgt", + /*data_alignment=*/128, + /*offset_factor=*/16, + /*buffer_type=*/kDefault); + Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); + Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); + Stmt wmma_body = BlockRealize( + /*iter_values=*/{}, + /*predicate=*/Bool(true), + Block( + /*iter_vars=*/{}, + /*reads=*/{BufferRegion(src_buffer, read_region)}, + /*writes=*/{BufferRegion(tgt_buffer, write_region)}, + /*name_hint=*/"wmma_load", + /*body=*/ + Evaluate(Call( + /*data=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_load_matrix_sync(), + { + /*0:*/ new_tgt_buffer->data, + /*1:*/ 16, + /*2:*/ 16, + /*3:*/ 16, + /*4:*/ floordiv(new_tgt_buffer->elem_offset, 256) + + floordiv(floormod(new_tgt_buffer->elem_offset, 256), 16), + /*5:*/ + Call( + /*dtype=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_access_ptr(), + /*args=*/ + { + /*0:*/ TypeAnnotation(new_src_buffer->dtype), + /*1:*/ new_src_buffer->data, + /*2:*/ new_src_buffer->elem_offset, + /*3:*/ new_src_buffer->strides[new_src_buffer->strides.size() - 2] * 16, + /*4:*/ 1, + }), + /*6:*/ new_src_buffer->strides[new_src_buffer->strides.size() - 2], + /*7:*/ StringImm("row_major"), + })), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/ + { + /*0:*/ MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region)), + /*1:*/ MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region)), + }, + /*annotations=*/{})); + for (int i = n - 3; i >= 0; i--) { + wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, + std::move(wmma_body), loops[i]->thread_binding, loops[i]->annotations); + } + return wmma_body; +} + +/*! + * \brief Rewrite the data copy that loads from wmma fragment with wmma::store_matrix_sync + * \param stmt The stmt to rewrite + * \return The stmt after rewrite + */ +Stmt RewriteWmmaStore(Stmt stmt) { + using arith::IntSet; + const DataType dtype = DataType::Float(32); + const DataType int32 = DataType::Int(32); + + Stmt body = stmt; + std::vector loops; + while (const ForNode* loop = body.as()) { + loops.push_back(loop); + body = loop->body; + } + int n = loops.size(); + + Map var_dom{ + {loops[n - 1]->loop_var, IntSet::FromMinExtent(loops[n - 1]->min, loops[n - 1]->extent)}, + {loops[n - 2]->loop_var, IntSet::FromMinExtent(loops[n - 2]->min, loops[n - 2]->extent)}, + }; + // TODO: the assumption that the RHS of BufferStore is BufferLoad may not be accurate + const BufferStoreNode* buf_store = TVM_TYPE_AS(buf_store, body, BufferStoreNode); + const BufferLoadNode* buf_load = TVM_TYPE_AS(buf_load, buf_store->value, BufferLoadNode); + Buffer src_buffer = buf_load->buffer; + Buffer tgt_buffer = buf_store->buffer; + + Buffer new_src_buffer(/*data=*/Var("src", PointerType(PrimType(dtype), src_buffer.scope())), + /*dtype=*/dtype, + /*shape=*/{Integer(16), Integer(16)}, + /*strides=*/{}, + /*elem_offset=*/Var("src_elem_offset", int32), + /*name=*/"src", + /*data_alignment=*/128, + /*offset_factor=*/16, + /*buffer_type=*/kDefault); + Buffer new_tgt_buffer(/*data=*/Var("tgt", PointerType(PrimType(dtype), tgt_buffer.scope())), + /*dtype=*/dtype, + /*shape=*/{Integer(16), Integer(16)}, + /*strides=*/{Var("s1", int32), Var("s0", int32)}, + /*elem_offset=*/Var("tgt_elem_offset", int32), + /*name=*/"tgt", + /*data_alignment=*/128, + /*offset_factor=*/16, + /*buffer_type=*/kDefault); + + Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); + Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); + + Stmt wmma_body = BlockRealize( + /*iter_values=*/{}, // + /*predicate=*/Bool(true), + Block(/*iter_vars=*/{}, + /*reads=*/{BufferRegion(src_buffer, read_region)}, + /*writes=*/{BufferRegion(tgt_buffer, write_region)}, + /*name_hint=*/"wmma_store", + Evaluate(Call( + /*data=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_store_matrix_sync(), + {/*0:*/ new_src_buffer->data, + /*1:*/ 16, + /*2:*/ 16, + /*3:*/ 16, + /*4:*/ floordiv(new_src_buffer->elem_offset, 256) + + floordiv(floormod(new_src_buffer->elem_offset, 256), 16), + /*5:*/ + Call( + /*data=*/runtime::DataType::Handle(), + /*op=*/builtin::tvm_access_ptr(), + { + /*0:*/ TypeAnnotation(new_tgt_buffer->dtype), + /*1:*/ new_tgt_buffer->data, + /*2:*/ new_tgt_buffer->elem_offset, + /*3:*/ new_tgt_buffer->strides[0] * 16, + /*4:*/ 2, + }), + /*6:*/ new_tgt_buffer->strides[0], + /*7:*/ StringImm("row_major")})), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/ + { + MatchBufferRegion(new_src_buffer, BufferRegion(src_buffer, read_region)), + MatchBufferRegion(new_tgt_buffer, BufferRegion(tgt_buffer, write_region)), + }, + /*annotations=*/{})); + for (int i = n - 3; i >= 0; i--) { + wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, + std::move(wmma_body), loops[i]->thread_binding, loops[i]->annotations); + } + return wmma_body; +} + +Stmt SharedToWmma::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt after_tiling = TileWmmaBlock(stmt).first; + output->padding_min.Set(constraints.read_region->buffer, 8); + return RewriteWmmaLoad(after_tiling); +} + +Stmt WmmaToShared::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt after_tiling = TileWmmaBlock(stmt).first; + output->padding_min.Set(constraints.write_region->buffer, 8); + return RewriteWmmaStore(after_tiling); +} + +class WmmaToGlobalRewriter : public StmtExprMutator { + public: + WmmaToGlobalRewriter(const SeqStmtNode* tgt_stmt, const ConstraintSet& constraints) + : tgt_stmt_(tgt_stmt), constraints_(constraints) {} + + private: + Stmt VisitStmt_(const SeqStmtNode* op) final { + if (op == tgt_stmt_) { + ICHECK_EQ(op->seq.size(), 2); + Stmt wmma_to_shared = RewriteWmmaStore(op->seq[0]); + Stmt shared_to_global = CoalescedAccess().Rewrite(op->seq[1], constraints_, nullptr); + return SeqStmt({wmma_to_shared, shared_to_global}); + } else { + return StmtMutator::VisitStmt_(op); + } + } + + const SeqStmtNode* tgt_stmt_; + const ConstraintSet& constraints_; +}; + +Stmt WmmaToGlobal::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, + OutputSet* output) const { + Stmt body{nullptr}; + Optional compute_location{nullptr}; + std::tie(body, compute_location) = TileWmmaBlock(stmt); + SeqStmt seq{nullptr}; + Buffer cache_buffer; + // Step 1. add a shared memory cache + std::tie(body, seq) = InsertCacheStage(std::move(body), true, "shared.dyn", compute_location, + constraints.outer_loops, &cache_buffer); + output->alloc_buffer.push_back(cache_buffer); + output->padding_min.Set(cache_buffer, 8); + // Step 2. do coalesced rewrite and tensor core rewrite respectively for 2 parts + WmmaToGlobalRewriter rewriter(seq.get(), constraints); + return rewriter(body); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/renormalize_split_pattern.cc b/src/tir/transforms/renormalize_split_pattern.cc new file mode 100644 index 000000000000..dd19d7923e77 --- /dev/null +++ b/src/tir/transforms/renormalize_split_pattern.cc @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file renormalize_split_pattern.cc + * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) + */ +#include +#include +#include +#include +#include +#include + +#include "../../arith/ir_mutator_with_analyzer.h" +#include "../../arith/pattern_match.h" + +namespace tvm { +namespace tir { + +using namespace arith; + +// macro for doing simple rewrite +#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \ + if ((SrcExpr).Match(ret)) { \ + return (ResExpr).Eval(); \ + } + +// macro rewrite + recursive_rewrite only if CondExor is true after match. +#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ + if ((SrcExpr).Match(ret) && (CondExpr)) { \ + return RecursiveRewrite((ResExpr).Eval()); \ + } + +class SplitPatternReNormalizer : public IRMutatorWithAnalyzer { + public: + explicit SplitPatternReNormalizer(Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {} + + PrimExpr VisitExpr_(const FloorDivNode* op) final { + PrimExpr a = VisitExpr(op->a); + PrimExpr b = VisitExpr(op->b); + PrimExpr ret = floordiv(a, b); + // Pattern var to match any expression + PVar x, y, z; + // Pattern var match IntImm + PVar c1, c2, c3; + // Pattern var for lanes in broadcast and ramp + PVar lanes; + + // floordiv(floormod(x, c1 * c2), c2) = floormod(floordiv(x, c2), c1) + TVM_TRY_RECURSIVE_REWRITE_IF(floordiv(floormod(x, c3), c2), + floormod(floordiv(x, c2), floordiv(c3, c2)), + c3.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_RECURSIVE_REWRITE_IF( + floordiv(floormod(x, broadcast(c3, lanes)), broadcast(c2, lanes)), + floormod(floordiv(x, broadcast(c2, lanes)), broadcast(floordiv(c3, c2), lanes)), + c3.Eval()->value % c2.Eval()->value == 0); + + // floordiv(x*c1*c3 + y, c2*c3) = floordiv(x*c1 + floordiv(y, c3), c2) + if ((floordiv(x * c1 + y, c2)).Match(ret)) { + int64_t c1_val = c1.Eval()->value; + int64_t c2_val = c2.Eval()->value; + if (c1_val > 0 && c2_val > 0) { + int64_t c3 = ZeroAwareGCD(c1_val, c2_val); + if (c3 > 1) { + IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3); + IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3); + return RecursiveRewrite(floordiv(x.Eval() * c1_div + floordiv(y.Eval(), c3), c2_div)); + } + } + } + if ((floordiv(x * broadcast(c1, lanes) + y, broadcast(c2, lanes))).Match(ret)) { + int64_t c1_val = c1.Eval()->value; + int64_t c2_val = c2.Eval()->value; + if (c1_val > 0 && c2_val > 0) { + int64_t c3 = ZeroAwareGCD(c1_val, c2_val); + if (c3 > 1) { + IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3); + IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3); + return RecursiveRewrite(floordiv( + x.Eval() * Broadcast(c1_div, lanes.Eval()) + + floordiv(y.Eval(), Broadcast(IntImm(c1.Eval().dtype(), c3), lanes.Eval())), + Broadcast(c2_div, lanes.Eval()))); + } + } + } + + // floordiv(x*c1*c3 + y + z, c2*c3) = floordiv(x*c1 + floordiv(y + z, c3), c2) + if ((floordiv(x * c1 + y + z, c2)).Match(ret)) { + int64_t c1_val = c1.Eval()->value; + int64_t c2_val = c2.Eval()->value; + if (c1_val > 0 && c2_val > 0) { + int64_t c3 = ZeroAwareGCD(c1_val, c2_val); + if (c3 > 1) { + IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3); + IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3); + return RecursiveRewrite(floordiv(x.Eval() * c1_div + floordiv(y.Eval() + z.Eval(), c3), c2_div)); + } + } + } + if ((floordiv(x * broadcast(c1, lanes) + y + z, broadcast(c2, lanes))).Match(ret)) { + int64_t c1_val = c1.Eval()->value; + int64_t c2_val = c2.Eval()->value; + if (c1_val > 0 && c2_val > 0) { + int64_t c3 = ZeroAwareGCD(c1_val, c2_val); + if (c3 > 1) { + IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3); + IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3); + return RecursiveRewrite(floordiv( + x.Eval() * Broadcast(c1_div, lanes.Eval()) + + floordiv(y.Eval() + z.Eval(), Broadcast(IntImm(c1.Eval().dtype(), c3), lanes.Eval())), + Broadcast(c2_div, lanes.Eval()))); + } + } + } + + return ret; + } + + PrimExpr VisitExpr_(const LENode* op) { return this->VisitExpr(Not(op->b < op->a)); } + + PrimExpr VisitExpr_(const GTNode* op) { return this->VisitExpr(op->b < op->a); } + + PrimExpr VisitExpr_(const GENode* op) { return this->VisitExpr(Not(op->a < op->b)); } + + PrimExpr VisitExpr_(const LTNode* op) { + PrimExpr a = VisitExpr(op->a); + PrimExpr b = VisitExpr(op->b); + PrimExpr ret = tir::LT(a, b); + // Pattern var to match any expression + PVar x; + // Pattern var match IntImm + PVar c1, c2; + TVM_TRY_RECURSIVE_REWRITE_IF(xvalue> 0); + return ret; + } + + PrimExpr VisitExpr_(const NotNode* op) { + PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); + // Pattern var to match any expression + PVar x, y; + TVM_TRY_REWRITE(!(!x), x); + TVM_TRY_REWRITE(!(x <= y), y < x); + TVM_TRY_REWRITE(!(x >= y), x < y); + TVM_TRY_REWRITE(!(x < y), y <= x); + TVM_TRY_REWRITE(!(x > y), x <= y); + return ret; + } + + Stmt VisitStmt_(const ForNode* op) final { + analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + With ctx1(analyzer_, op->loop_var >= op->min); + With ctx2(analyzer_, op->loop_var < op->min + op->extent); + return IRMutatorWithAnalyzer::VisitStmt_(op); + } + + // Recursive rewrite x + // we limit maximum depth of recursive rewrite allowed to + // avoid infinite loop + PrimExpr RecursiveRewrite(const PrimExpr& x) { + if (recur_depth_ >= kMaxRecurDepth) return x; + ++recur_depth_; + PrimExpr res = this->VisitExpr(x); + --recur_depth_; + return res; + } + + private: + // counter to record recursive rewrite depth. + int recur_depth_{0}; + // maximum number of recursion allowed during a single pass. + static const constexpr int kMaxRecurDepth = 5; +}; + +namespace transform { + +Pass RenormalizeSplitPattern() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + arith::Analyzer analyzer; + n->body = SplitPatternReNormalizer(&analyzer)(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.RenormalizeSplitPattern", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.RenormalizeSplitPattern") + .set_body_typed(RenormalizeSplitPattern); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc index aa586846f5d4..9c1aab634496 100644 --- a/src/tir/transforms/unify_thread_binding.cc +++ b/src/tir/transforms/unify_thread_binding.cc @@ -58,8 +58,15 @@ class ThreadBindingUnifier : public StmtExprMutator { if (op->kind != ForKind::kThreadBinding) { return StmtExprMutator::VisitStmt_(op); } - return UnifyThreadBindingImpl(op, op->loop_var, op->thread_binding.value(), - Range::FromMinExtent(op->min, op->extent)); + Map annotations = op->annotations; + Stmt stmt = UnifyThreadBindingImpl(op, op->loop_var, op->thread_binding.value(), + Range::FromMinExtent(op->min, op->extent)); + if (annotations.empty()) { + return stmt; + } + For new_loop = Downcast(stmt); + new_loop.CopyOnWrite()->annotations = std::move(annotations); + return new_loop; } template @@ -70,7 +77,7 @@ class ThreadBindingUnifier : public StmtExprMutator { const String& thread_tag = old_iter_var->thread_tag; // Step 2: Increase `thread_block_depth_` if the thread tag starts with "blockIdx". If the - // thread block depth is 0 before the increasement, it means we are entering a new kernel, and + // thread block depth is 0 before the increment, it means we are entering a new kernel, and // therefore we need to make `thread_tag2iter_var_map_` empty, as different kernels can have // thread axes with different extents. bool is_kernel_launch_scope = false; diff --git a/tests/python/meta_schedule/run_ansor_cpu.sh b/tests/python/meta_schedule/run_ansor_cpu.sh new file mode 100644 index 000000000000..a080ded8fdd9 --- /dev/null +++ b/tests/python/meta_schedule/run_ansor_cpu.sh @@ -0,0 +1,41 @@ +set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="raspi4b-aarch64" +TARGET="raspberry-pi/4b-64" +NUM_TRIALS=800 +LOG_DIR=$HOME/logs/ansor-cpu/ + +mkdir -p $LOG_DIR + +run () { + name=$1 + echo "Running workload $name" + python tests/python/meta_schedule/test_ansor.py \ + --workload "$name" \ + --target "$TARGET" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials "$NUM_TRIALS" \ + --log-dir $LOG_DIR \ + 2>&1 | tee "$LOG_DIR/$name.log" +} + +# Single op +run C1D +run C2D +run C3D +run CAP +run DEP +run DIL +run GMM +run GRP +run NRM +run SFM +run T2D +# Subgraph +run C2d-BN-RELU +run TBG + diff --git a/tests/python/meta_schedule/run_ansor_cuda.sh b/tests/python/meta_schedule/run_ansor_cuda.sh new file mode 100644 index 000000000000..6eda12fe119c --- /dev/null +++ b/tests/python/meta_schedule/run_ansor_cuda.sh @@ -0,0 +1,39 @@ +# set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="jetson-agx-xavier" +TARGET="nvidia/jetson-agx-xavier" +LOG_DIR=$HOME/logs/ansor-cuda/ +NUM_TRIALS=2000 + +mkdir -p $LOG_DIR + +run () { + name=$1 + echo "Running workload $name" + python tests/python/meta_schedule/test_ansor.py \ + --workload "$name" \ + --target "$TARGET" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials "$NUM_TRIALS" \ + --log-dir $LOG_DIR \ + 2>&1 | tee "$LOG_DIR/$name.log" +} + +run C1D +run C2D +run CAP +run DEP +run DIL +run GMM +run GRP +run T2D +run C2d-BN-RELU +run TBG + +run C3D +run NRM +run SFM diff --git a/tests/python/meta_schedule/run_meta_schedule_cpu.sh b/tests/python/meta_schedule/run_meta_schedule_cpu.sh new file mode 100644 index 000000000000..87bc17f9e8b6 --- /dev/null +++ b/tests/python/meta_schedule/run_meta_schedule_cpu.sh @@ -0,0 +1,40 @@ +set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="raspi4b-aarch64" +TARGET="raspberry-pi/4b-64" +LOG_DIR=$HOME/logs/ms-cpu/ +NUM_TRIALS=2000 + +mkdir -p $LOG_DIR + +run () { + name=$1 + echo "Running workload $name" + python tests/python/meta_schedule/test_meta_schedule.py \ + --workload "$name" \ + --target "$TARGET" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials $NUM_TRIALS \ + 2>&1 | tee "$LOG_DIR/$name.log" +} + +# Single op +run C1D +run C2D +run C3D +run CAP +run DEP +run DIL +run GMM +run GRP +run NRM +run SFM +run T2D +# Subgraph +run C2d-BN-RELU +run TBG + diff --git a/tests/python/meta_schedule/run_meta_schedule_cuda.sh b/tests/python/meta_schedule/run_meta_schedule_cuda.sh new file mode 100644 index 000000000000..28132a05045a --- /dev/null +++ b/tests/python/meta_schedule/run_meta_schedule_cuda.sh @@ -0,0 +1,41 @@ +set -euxo pipefail + +RPC_HOST="192.168.6.66" +RPC_PORT="4445" +RPC_KEY="jetson-agx-xavier" +TARGET="nvidia/jetson-agx-xavier" +LOG_DIR=$HOME/logs/ms-cuda/ +NUM_TRIALS=2000 + +mkdir -p $LOG_DIR + +run () { + name=$1 + work_dir=$LOG_DIR/$name/ + mkdir -p $work_dir + echo "Running workload $name" + python tests/python/meta_schedule/test_meta_schedule.py \ + --workload "$name" \ + --target "$TARGET" \ + --work-dir "$work_dir" \ + --rpc-host "$RPC_HOST" \ + --rpc-port "$RPC_PORT" \ + --rpc-key "$RPC_KEY" \ + --num-trials $NUM_TRIALS \ + 2>&1 | tee "$work_dir/$name.log" +} + +run C1D +run C2D +run CAP +run DEP +run DIL +run GMM +run GRP +run T2D +run C2d-BN-RELU +run TBG + +run C3D +run NRM +run SFM diff --git a/tests/python/meta_schedule/test_ansor.py b/tests/python/meta_schedule/test_ansor.py new file mode 100644 index 000000000000..1e548c49afa3 --- /dev/null +++ b/tests/python/meta_schedule/test_ansor.py @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import argparse +import os + +import tvm +from tvm import auto_scheduler +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.te_workload import CONFIGS + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--workload", + type=str, + required=True, + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + required=True, + ) + args.add_argument( + "--rpc-port", + type=int, + required=True, + ) + args.add_argument( + "--rpc-key", + type=str, + required=True, + ) + args.add_argument( + "--log-dir", + type=str, + required=True, + ) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=60, + ) + parsed.rpc_workers = rpc_config.count_num_servers(allow_missing=False) + return parsed + + +ARGS = _parse_args() + + +def main(): + log_file = os.path.join(ARGS.log_dir, f"{ARGS.workload}.json") + workload_func, params = CONFIGS[ARGS.workload] + params = params[0] + workload_func = auto_scheduler.register_workload(workload_func) + + if ARGS.target.device_name == "cpu": + hardware_params = auto_scheduler.HardwareParams( + num_cores=int(ARGS.target.attrs["num-cores"]), + target=ARGS.target, + ) + else: + hardware_params = auto_scheduler.HardwareParams( + num_cores=-1, + vector_unit_bytes=16, + cache_line_bytes=64, + max_shared_memory_per_block=int(ARGS.target.attrs["shared_memory_per_block"]), + max_local_memory_per_block=int(ARGS.target.attrs["registers_per_block"]), + max_threads_per_block=int(ARGS.target.attrs["max_threads_per_block"]), + max_vthread_extent=8, + warp_size=32, + ) + task = auto_scheduler.SearchTask( + func=workload_func, + args=params, + target=ARGS.target, + hardware_params=hardware_params, + ) + runner = auto_scheduler.RPCRunner( + key=ARGS.rpc_key, + host=ARGS.rpc_host, + port=ARGS.rpc_port, + n_parallel=ARGS.rpc_workers, + ) + + # Inspect the computational graph + print("Computational DAG:") + print(task.compute_dag) + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=ARGS.num_trials, + measure_callbacks=[auto_scheduler.RecordToFile(log_file)], + verbose=2, + runner=runner, + ) + print("Running AutoTuning:") + task.tune(tune_option) + print("History Best:") + print(task.print_best(log_file)) + sch, args = task.apply_best(log_file) + print("Lowered TIR:") + print(tvm.lower(sch, args, simple_mode=True)) + + +if __name__ == "__main__": + main() diff --git a/tests/python/meta_schedule/test_debug_ansor.py b/tests/python/meta_schedule/test_debug_ansor.py new file mode 100644 index 000000000000..be562963a1a0 --- /dev/null +++ b/tests/python/meta_schedule/test_debug_ansor.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from typing import Tuple + +import tvm +from tvm import te, topi + + +TARGET = tvm.target.Target("nvidia/jetson-agx-xavier") + +@tvm.register_func +def tvm_callback_cuda_postproc(code): + import os + if not os.path.exists("/tmp/perf"): + os.mkdir("/tmp/perf") + with open("/tmp/perf/te.cu", "w") as f: + f.write(code) + return code + + +def func( # pylint: disable=invalid-name,missing-docstring + N: int, + L: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, L, CI), name="inputs") + weight = te.placeholder((kernel_size, CI // groups, CO), name="weight") + + batch_size, in_len, _ = inputs.shape + k_len, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + out_len = (in_len + 2 * padding - dilation * (k_len - 1) - 1) // stride + 1 + rc = te.reduce_axis((0, channel_per_group), name="rc") + rl = te.reduce_axis((0, k_len), name="rl") + + padded = topi.nn.pad(inputs, [0, padding, 0]) + output = te.compute( + (batch_size, out_len, out_channel), + lambda n, l, co: te.sum( + ( + padded[ + n, + l * stride + rl * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rl, rc, co] + ), + axis=[rl, rc], + ), + name="conv1d_nlc", + ) + return (inputs, weight, padded, output) + + +def main(): + inputs, weight, PadInput, conv1d_nlc = func(1, 256, 64, 128, 3, 2, 1) + s = te.create_schedule(conv1d_nlc.op) + # fmt: off + PadInput_i0, PadInput_i1, PadInput_i2 = tuple(PadInput.op.axis) + tuple(PadInput.op.reduce_axis) + conv1d_nlc_n, conv1d_nlc_l, conv1d_nlc_co, conv1d_nlc_rl, conv1d_nlc_rc = tuple(conv1d_nlc.op.axis) + tuple(conv1d_nlc.op.reduce_axis) + conv1d_nlc_local, = s.cache_write([conv1d_nlc], "local") + conv1d_nlc_local_n_c, conv1d_nlc_local_l_c, conv1d_nlc_local_co_c, conv1d_nlc_local_rl, conv1d_nlc_local_rc = tuple(conv1d_nlc_local.op.axis) + tuple(conv1d_nlc_local.op.reduce_axis) + conv1d_nlc_local_n_c_o_i, conv1d_nlc_local_n_c_i = s[conv1d_nlc_local].split(conv1d_nlc_local_n_c, factor=1) + conv1d_nlc_local_n_c_o_o_i, conv1d_nlc_local_n_c_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_n_c_o_i, factor=1) + conv1d_nlc_local_n_c_o_o_o_i, conv1d_nlc_local_n_c_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_n_c_o_o_i, factor=1) + conv1d_nlc_local_n_c_o_o_o_o, conv1d_nlc_local_n_c_o_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_n_c_o_o_o_i, factor=1) + conv1d_nlc_local_l_c_o_i, conv1d_nlc_local_l_c_i = s[conv1d_nlc_local].split(conv1d_nlc_local_l_c, factor=1) + conv1d_nlc_local_l_c_o_o_i, conv1d_nlc_local_l_c_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_l_c_o_i, factor=4) + conv1d_nlc_local_l_c_o_o_o_i, conv1d_nlc_local_l_c_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_l_c_o_o_i, factor=8) + conv1d_nlc_local_l_c_o_o_o_o, conv1d_nlc_local_l_c_o_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_l_c_o_o_o_i, factor=1) + conv1d_nlc_local_co_c_o_i, conv1d_nlc_local_co_c_i = s[conv1d_nlc_local].split(conv1d_nlc_local_co_c, factor=2) + conv1d_nlc_local_co_c_o_o_i, conv1d_nlc_local_co_c_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_co_c_o_i, factor=1) + conv1d_nlc_local_co_c_o_o_o_i, conv1d_nlc_local_co_c_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_co_c_o_o_i, factor=16) + conv1d_nlc_local_co_c_o_o_o_o, conv1d_nlc_local_co_c_o_o_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_co_c_o_o_o_i, factor=1) + conv1d_nlc_local_rl_o_i, conv1d_nlc_local_rl_i = s[conv1d_nlc_local].split(conv1d_nlc_local_rl, factor=3) + conv1d_nlc_local_rl_o_o, conv1d_nlc_local_rl_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_rl_o_i, factor=1) + conv1d_nlc_local_rc_o_i, conv1d_nlc_local_rc_i = s[conv1d_nlc_local].split(conv1d_nlc_local_rc, factor=2) + conv1d_nlc_local_rc_o_o, conv1d_nlc_local_rc_o_i = s[conv1d_nlc_local].split(conv1d_nlc_local_rc_o_i, factor=8) + s[conv1d_nlc_local].reorder(conv1d_nlc_local_n_c_o_o_o_o, conv1d_nlc_local_l_c_o_o_o_o, conv1d_nlc_local_co_c_o_o_o_o, conv1d_nlc_local_n_c_o_o_o_i, conv1d_nlc_local_l_c_o_o_o_i, conv1d_nlc_local_co_c_o_o_o_i, conv1d_nlc_local_n_c_o_o_i, conv1d_nlc_local_l_c_o_o_i, conv1d_nlc_local_co_c_o_o_i, conv1d_nlc_local_rl_o_o, conv1d_nlc_local_rc_o_o, conv1d_nlc_local_rl_o_i, conv1d_nlc_local_rc_o_i, conv1d_nlc_local_n_c_o_i, conv1d_nlc_local_l_c_o_i, conv1d_nlc_local_co_c_o_i, conv1d_nlc_local_rl_i, conv1d_nlc_local_rc_i, conv1d_nlc_local_n_c_i, + conv1d_nlc_local_l_c_i, conv1d_nlc_local_co_c_i) + conv1d_nlc_n_o_i, conv1d_nlc_n_i = s[conv1d_nlc].split(conv1d_nlc_n, factor=1) + conv1d_nlc_n_o_o_i, conv1d_nlc_n_o_i = s[conv1d_nlc].split(conv1d_nlc_n_o_i, factor=1) + conv1d_nlc_n_o_o_o, conv1d_nlc_n_o_o_i = s[conv1d_nlc].split(conv1d_nlc_n_o_o_i, factor=1) + conv1d_nlc_l_o_i, conv1d_nlc_l_i = s[conv1d_nlc].split(conv1d_nlc_l, factor=4) + conv1d_nlc_l_o_o_i, conv1d_nlc_l_o_i = s[conv1d_nlc].split(conv1d_nlc_l_o_i, factor=8) + conv1d_nlc_l_o_o_o, conv1d_nlc_l_o_o_i = s[conv1d_nlc].split(conv1d_nlc_l_o_o_i, factor=1) + conv1d_nlc_co_o_i, conv1d_nlc_co_i = s[conv1d_nlc].split(conv1d_nlc_co, factor=2) + conv1d_nlc_co_o_o_i, conv1d_nlc_co_o_i = s[conv1d_nlc].split(conv1d_nlc_co_o_i, factor=16) + conv1d_nlc_co_o_o_o, conv1d_nlc_co_o_o_i = s[conv1d_nlc].split(conv1d_nlc_co_o_o_i, factor=1) + s[conv1d_nlc].reorder(conv1d_nlc_n_o_o_o, conv1d_nlc_l_o_o_o, conv1d_nlc_co_o_o_o, conv1d_nlc_n_o_o_i, conv1d_nlc_l_o_o_i, conv1d_nlc_co_o_o_i, conv1d_nlc_n_o_i, conv1d_nlc_l_o_i, conv1d_nlc_co_o_i, conv1d_nlc_n_i, conv1d_nlc_l_i, conv1d_nlc_co_i) + s[conv1d_nlc_local].compute_at(s[conv1d_nlc], conv1d_nlc_co_o_i) + weight_shared = s.cache_read(weight, "shared", [conv1d_nlc_local]) + weight_shared_ax0, weight_shared_ax1, weight_shared_ax2 = tuple(weight_shared.op.axis) + s[weight_shared].compute_at(s[conv1d_nlc_local], conv1d_nlc_local_rc_o_o) + PadInput_shared = s.cache_read(PadInput, "shared", [conv1d_nlc_local]) + PadInput_shared_ax0, PadInput_shared_ax1, PadInput_shared_ax2 = tuple(PadInput_shared.op.axis) + s[PadInput_shared].compute_at(s[conv1d_nlc_local], conv1d_nlc_local_rc_o_o) + s[PadInput].compute_inline() + conv1d_nlc_n_o_o_o_l_o_o_o_fused_co_o_o_o_fused = s[conv1d_nlc].fuse(conv1d_nlc_n_o_o_o, conv1d_nlc_l_o_o_o, conv1d_nlc_co_o_o_o) + s[conv1d_nlc].bind(conv1d_nlc_n_o_o_o_l_o_o_o_fused_co_o_o_o_fused, te.thread_axis("blockIdx.x")) + conv1d_nlc_n_o_o_i_l_o_o_i_fused_co_o_o_i_fused = s[conv1d_nlc].fuse(conv1d_nlc_n_o_o_i, conv1d_nlc_l_o_o_i, conv1d_nlc_co_o_o_i) + s[conv1d_nlc].bind(conv1d_nlc_n_o_o_i_l_o_o_i_fused_co_o_o_i_fused, te.thread_axis("vthread")) + conv1d_nlc_n_o_i_l_o_i_fused_co_o_i_fused = s[conv1d_nlc].fuse(conv1d_nlc_n_o_i, conv1d_nlc_l_o_i, conv1d_nlc_co_o_i) + s[conv1d_nlc].bind(conv1d_nlc_n_o_i_l_o_i_fused_co_o_i_fused, te.thread_axis("threadIdx.x")) + weight_shared_ax0_ax1_fused_ax2_fused = s[weight_shared].fuse(weight_shared_ax0, weight_shared_ax1, weight_shared_ax2) + weight_shared_ax0_ax1_fused_ax2_fused_o, weight_shared_ax0_ax1_fused_ax2_fused_i = s[weight_shared].split(weight_shared_ax0_ax1_fused_ax2_fused, factor=1) + s[weight_shared].vectorize(weight_shared_ax0_ax1_fused_ax2_fused_i) + weight_shared_ax0_ax1_fused_ax2_fused_o_o, weight_shared_ax0_ax1_fused_ax2_fused_o_i = s[weight_shared].split(weight_shared_ax0_ax1_fused_ax2_fused_o, factor=128) + s[weight_shared].bind(weight_shared_ax0_ax1_fused_ax2_fused_o_i, te.thread_axis("threadIdx.x")) + PadInput_shared_ax0_ax1_fused_ax2_fused = s[PadInput_shared].fuse(PadInput_shared_ax0, PadInput_shared_ax1, PadInput_shared_ax2) + PadInput_shared_ax0_ax1_fused_ax2_fused_o, PadInput_shared_ax0_ax1_fused_ax2_fused_i = s[PadInput_shared].split(PadInput_shared_ax0_ax1_fused_ax2_fused, factor=1) + s[PadInput_shared].vectorize(PadInput_shared_ax0_ax1_fused_ax2_fused_i) + PadInput_shared_ax0_ax1_fused_ax2_fused_o_o, PadInput_shared_ax0_ax1_fused_ax2_fused_o_i = s[PadInput_shared].split(PadInput_shared_ax0_ax1_fused_ax2_fused_o, factor=128) + s[PadInput_shared].bind(PadInput_shared_ax0_ax1_fused_ax2_fused_o_i, te.thread_axis("threadIdx.x")) + # s[conv1d_nlc_local].pragma(conv1d_nlc_local_n_c_o_o_o_o, "auto_unroll_max_step", 1024) + # s[conv1d_nlc_local].pragma(conv1d_nlc_local_n_c_o_o_o_o, "unroll_explicit", True) + # fmt: off + print(tvm.lower(s, [inputs, weight, conv1d_nlc]).script()) + tvm.build(s, [inputs, weight, conv1d_nlc], target=TARGET) + + +if __name__ == "__main__": + main() diff --git a/tests/python/meta_schedule/test_debug_meta_schedule.py b/tests/python/meta_schedule/test_debug_meta_schedule.py new file mode 100644 index 000000000000..b93a01dae737 --- /dev/null +++ b/tests/python/meta_schedule/test_debug_meta_schedule.py @@ -0,0 +1,163 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring + +from typing import List + +import tvm +from tvm import meta_schedule as ms +from tvm.ir import IRModule +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import Postproc +from tvm.meta_schedule.testing import create_te_workload +from tvm.meta_schedule.tune import DefaultCUDA, DefaultLLVM +from tvm.meta_schedule.utils import remove_build_dir +from tvm.target import Target +from tvm.tir import Schedule + + +RPC_HOST = "192.168.6.66" +RPC_PORT = 4445 +RPC_KEY = "jetson-agx-xavier" +TARGET = Target("nvidia/jetson-agx-xavier") +WORKLOAD = "C1D" +POSTPROCS: List[Postproc] = DefaultCUDA._postproc() # pylint: disable=protected-access + +TARGET = tvm.target.Target("nvidia/jetson-agx-xavier") + + +@tvm.register_func +def tvm_callback_cuda_postproc(code): + import os + + if not os.path.exists("/tmp/perf"): + os.mkdir("/tmp/perf") + with open("/tmp/perf/tir.cu", "w") as f: + f.write(code) + return code + + +def schedule_fn(sch: Schedule): + # pylint: disable=invalid-name,line-too-long,unused-variable + # fmt: off + b0 = sch.get_block(name="PadInput", func_name="main") + b1 = sch.get_block(name="conv1d_nlc", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + b3 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local") + l4, l5, l6, l7, l8 = sch.get_loops(block=b1) + v9, v10, v11, v12, v13 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1]) + l14, l15, l16, l17, l18 = sch.split(loop=l4, factors=[v9, v10, v11, v12, v13]) + v19, v20, v21, v22, v23 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[4, 1, 8, 4, 1]) + l24, l25, l26, l27, l28 = sch.split(loop=l5, factors=[v19, v20, v21, v22, v23]) + v29, v30, v31, v32, v33 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[4, 1, 16, 1, 2]) + l34, l35, l36, l37, l38 = sch.split(loop=l6, factors=[v29, v30, v31, v32, v33]) + v39, v40, v41 = sch.sample_perfect_tile(loop=l7, n=3, max_innermost_factor=64, decision=[1, 1, 3]) + l42, l43, l44 = sch.split(loop=l7, factors=[v39, v40, v41]) + v45, v46, v47 = sch.sample_perfect_tile(loop=l8, n=3, max_innermost_factor=64, decision=[4, 8, 2]) + l48, l49, l50 = sch.split(loop=l8, factors=[v45, v46, v47]) + sch.reorder(l14, l24, l34, l15, l25, l35, l16, l26, l36, l42, l48, l43, l49, l17, l27, l37, l44, l50, l18, l28, l38) + l51 = sch.fuse(l14, l24, l34) + sch.bind(loop=l51, thread_axis="blockIdx.x") + l52 = sch.fuse(l15, l25, l35) + sch.bind(loop=l52, thread_axis="vthread.x") + l53 = sch.fuse(l16, l26, l36) + sch.bind(loop=l53, thread_axis="threadIdx.x") + + b54 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="shared") + sch.compute_at(block=b54, loop=l48, preserve_unit_loops=True) + l55, l56, l57, l58, l59, l60, l61, l62 = sch.get_loops(block=b54) + l63 = sch.fuse(l60, l61, l62) + v64, v65 = sch.sample_perfect_tile(loop=l63, n=2, max_innermost_factor=4, decision=[1040, 1]) + sch.annotate(block_or_loop=b54, ann_key="meta_schedule.cooperative_fetch", ann_val=v65) + + b66 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared") + sch.compute_at(block=b66, loop=l48, preserve_unit_loops=True) + l67, l68, l69, l70, l71, l72, l73, l74 = sch.get_loops(block=b66) + l75 = sch.fuse(l72, l73, l74) + v76, v77 = sch.sample_perfect_tile(loop=l75, n=2, max_innermost_factor=4, decision=[1536, 1]) + sch.annotate(block_or_loop=b66, ann_key="meta_schedule.cooperative_fetch", ann_val=v77) + + sch.reverse_compute_at(block=b3, loop=l53, preserve_unit_loops=True) + sch.compute_inline(block=b0) + # v78 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=4) + # sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v78) + # fmt: on + return sch + + +def _make_sch() -> Schedule: + prim_func = create_te_workload(WORKLOAD, 0) + prim_func = prim_func.with_attr("global_symbol", "main") + prim_func = prim_func.with_attr("tir.noalias", True) + mod = IRModule({"main": prim_func}) + return Schedule(mod, debug_mask="all") + + +def _apply_postproc(sch: Schedule): + sch.enter_postproc() + ctx = TuneContext(target=TARGET) + for p in POSTPROCS: + p.initialize_with_tune_context(ctx) + assert p.apply(sch) + + +def run_sch(sch: Schedule): + print(sch.mod.script()) + print(sch.trace) + print(tvm.lower(sch.mod).script()) + tvm.build(sch.mod, target=TARGET) + builder = ms.builder.LocalBuilder() + runner = ms.runner.RPCRunner( + rpc_config=ms.runner.RPCConfig( + tracker_host=RPC_HOST, + tracker_port=RPC_PORT, + tracker_key=RPC_KEY, + session_timeout_sec=60, + ), + alloc_repeat=3, + max_workers=5, + ) + (builder_result,) = builder.build( # pylint: disable=unbalanced-tuple-unpacking + [ms.builder.BuilderInput(sch.mod, TARGET)] + ) + if builder_result.error_msg is not None: + print(builder_result.error_msg) + return + try: + runner_input = ms.runner.RunnerInput( + builder_result.artifact_path, + device_type=TARGET.kind.name, + args_info=ms.arg_info.ArgInfo.from_prim_func(sch.mod["main"]), + ) + (runner_future,) = runner.run([runner_input]) # pylint: disable=unbalanced-tuple-unpacking + runner_result = runner_future.result() + if runner_result.error_msg is not None: + print(runner_result.error_msg) + else: + print([float(x) * 1000.0 for x in runner_result.run_secs]) + finally: + remove_build_dir(builder_result.artifact_path) + + +def main(): + sch = schedule_fn(_make_sch()) + _apply_postproc(sch) + run_sch(sch) + + +if __name__ == "__main__": + main() diff --git a/tests/python/meta_schedule/test_meta_schedule.py b/tests/python/meta_schedule/test_meta_schedule.py new file mode 100644 index 000000000000..64890f426791 --- /dev/null +++ b/tests/python/meta_schedule/test_meta_schedule.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import argparse +import logging +from os import cpu_count + +import tvm +from tvm import meta_schedule as ms +from tvm import tir +from tvm.meta_schedule.testing import create_te_workload + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument( + "--workload", + type=str, + required=True, + ) + args.add_argument( + "--target", + type=str, + required=True, + ) + args.add_argument( + "--num-trials", + type=int, + required=True, + ) + args.add_argument( + "--work-dir", + type=str, + required=True, + ) + args.add_argument( + "--rpc-host", + type=str, + required=True, + ) + args.add_argument( + "--rpc-port", + type=int, + required=True, + ) + args.add_argument( + "--rpc-key", + type=str, + required=True, + ) + parsed = args.parse_args() + parsed.target = tvm.target.Target(parsed.target) + if parsed.target.attrs.get("mtriple", None) == "aarch64-linux-gnu": + parsed.alloc_repeat = 3 + else: + parsed.alloc_repeat = 1 + parsed.rpc_config = ms.runner.RPCConfig( + tracker_host=parsed.rpc_host, + tracker_port=parsed.rpc_port, + tracker_key=parsed.rpc_key, + session_timeout_sec=30, + ) + parsed.rpc_workers = parsed.rpc_config.count_num_servers(allow_missing=False) + return parsed + + +logging.basicConfig() +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) +ARGS = _parse_args() + + +def main(): + runner = ms.runner.RPCRunner( + rpc_config=ARGS.rpc_config, + alloc_repeat=3, + max_workers=ARGS.rpc_workers, + ) + sch: tir.Schedule = ms.tune_tir( + mod=create_te_workload(ARGS.workload, 0), + target=ARGS.target, + config=ms.EvolutionarySearchConfig( + num_trials_per_iter=64, + num_trials_total=ARGS.num_trials, + init_min_unmeasured=50 + ), + runner=runner, + task_name=ARGS.workload, + work_dir=ARGS.work_dir, + num_threads=cpu_count(), + ) + if sch is None: + print("No valid schedule found!") + else: + print(sch.mod.script()) + print(sch.trace) + + +if __name__ == "__main__": + main() diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index b40f3c9f56ea..e8fae11d5cc9 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -102,7 +102,14 @@ def test_mod(): floordiv = tvm.te.floordiv z = te.var("z") ck.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 3)) - ck.verify(flm(y, 8), {y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)}, (0, 7)) + ck.verify( + flm(y, 8), + {y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)}, + ( + z * 8 + x * 4 - 8 * floordiv(z * 8 + x * 4, 8), + z * 8 + x * 4 + 3 - 8 * floordiv(z * 8 + x * 4, 8), + ), + ) ck1 = IntSetChecker() ck1.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 2)) ck1.verify( diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 6b3c29592eb6..e9ea09d0b793 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -625,7 +625,7 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[0][0], j0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(l0[0] * 6 + l1[0], 6)) tvm.ir.assert_structural_equal(res[1][0], 0) - tvm.ir.assert_structural_equal(res[1][1], floormod(floordiv(l0[0] * 6 + l1[0], 3), 2)) + tvm.ir.assert_structural_equal(res[1][1], floordiv(floormod(l0[0] * 6 + l1[0], 6), 3)) tvm.ir.assert_structural_equal(res[2][0], 0) tvm.ir.assert_structural_equal(res[2][1], (floormod(l0[0] * 6 + l1[0], 3) * 3) + j3[0]) diff --git a/tests/python/unittest/test_arith_modular_set.py b/tests/python/unittest/test_arith_modular_set.py index 4a4cd6a31ef1..7914195effe1 100644 --- a/tests/python/unittest/test_arith_modular_set.py +++ b/tests/python/unittest/test_arith_modular_set.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm.arith import analyzer def test_cast(): @@ -50,6 +51,14 @@ def test_mul(): assert m.base == 2 +def test_floormod(): + analyzer = tvm.arith.Analyzer() + x, y = te.var("x"), te.var("y") + m = analyzer.modular_set(tvm.tir.floormod(x * 128 + y * 4, 256)) + assert m.coeff == 4 + assert m.base == 0 + + def test_div_shift(): analyzer = tvm.arith.Analyzer() x, y = te.var("x"), te.var("y") @@ -175,6 +184,7 @@ def test_let(): test_add_sub() test_mul() test_div_shift() + test_floormod() test_min_max_select() test_mix_index() test_constraint_scope() diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 6ca2a2a5fcb0..549882126d50 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -80,6 +80,10 @@ def test_vector_simplify(): ck.verify(fld(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), (x).astype("int32x4")) ck.verify(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8)) ck.verify(fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)), tvm.tir.Ramp(fld(x, 4), 2, 5)) + ck.verify( + fld(tvm.tir.Ramp(flm(x * 4, 256), 1, 4), tvm.tir.Broadcast(8, 4)), + tvm.tir.Broadcast(fld(flm(x * 4, 256), 8), 4) + ) ck.verify( fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)), fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)), @@ -277,6 +281,7 @@ def test_add_index_simplify(): flm = tvm.te.floormod ck.verify(y * flm(x, 8) + 10 * flm(x, 8), flm(x, 8) * (y + 10)) ck.verify(fld(x, 8) * 8 + flm(x, 8), x) + ck.verify(fld(flm(x, 2) + 7, 2) + fld(x, 2), fld(x + 7, 2)) def test_sub_index_simplify(): diff --git a/tests/python/unittest/test_meta_schedule_byoc.py b/tests/python/unittest/test_meta_schedule_byoc.py new file mode 100644 index 000000000000..fe50350d5133 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_byoc.py @@ -0,0 +1,198 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Test Meta Schedule Builder """ +# pylint: disable=missing-docstring + +import sys + +import pytest +import tvm +from tvm import relay +from tvm.meta_schedule.arg_info import TensorInfo +from tvm.meta_schedule.builder import BuilderInput, LocalBuilder +from tvm.meta_schedule.runner import EvaluatorConfig, LocalRunner, RunnerInput +from tvm.meta_schedule.testing import get_network +from tvm.meta_schedule.testing.byoc_trt import ( + build_relay, + build_relay_with_tensorrt, + run_with_graph_executor, +) +from tvm.relay import testing +from tvm.relay.op.contrib import tensorrt +from tvm.target import Target +from tvm.tir import FloatImm + +has_tensorrt_codegen = pytest.mark.skipif( + not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT codegen not available" +) +has_tensorrt_runtime = pytest.mark.skipif( + not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available" +) + +# conv2d+relu network +def get_conv2d_relu( + data_shape, + out_channels, + kernel_size, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + dtype, +): + + data = relay.var("data", relay.TensorType(data_shape, dtype)) + weight = relay.var("weight") + + net = relay.nn.conv2d( + data=data, + weight=weight, # conv kernel + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + channels=out_channels, + kernel_size=kernel_size, + data_layout=data_layout, + kernel_layout=kernel_layout, + ) + net = relay.add(net, net) + net = relay.nn.relu(net) + + inputs = relay.analysis.free_vars(net) + return relay.Function(inputs, net) + + +def verify_meta_schedule_with_tensorrt( + mod, + params, + data_shape, + use_meta_sched: bool = True, + use_trt: bool = True, + mode: str = "vm", +): + if use_meta_sched: + # With meta_schedule + dev = "nvidia/geforce-rtx-2080" + # Build + builder = LocalBuilder( + f_build=build_relay_with_tensorrt if use_trt else build_relay, + timeout_sec=1000, + ) + builder_input = BuilderInput(mod, Target(dev, host="llvm"), params) + builder_result = builder.build([builder_input])[0] + assert builder_result.error_msg is None, builder_result.error_msg + assert builder_result.artifact_path is not None + + # Run + runner_input = RunnerInput( + builder_result.artifact_path, + device_type="cuda", + args_info=[TensorInfo("float32", data_shape)], + ) + runner = LocalRunner( + evaluator_config=EvaluatorConfig( + number=5, + repeat=2, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ), + f_run_evaluator=run_with_graph_executor, + ) + + # Run the module + runner_future = runner.run([runner_input])[0] + runner_result = runner_future.result() + assert runner_result is not None + assert runner_result.error_msg is None, runner_result.error_msg + assert runner_result.run_secs is not None + + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + + else: + # Without meta_schedule + if use_trt: + mod, config = tensorrt.partition_for_tensorrt(mod) + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.tensorrt.options": config} + ): + _func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda" + ).evaluate() + else: + with tvm.transform.PassContext(opt_level=3): + _func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda", params=params + ).evaluate() + + +@has_tensorrt_codegen +def test_conv2d_relu(): + data_shape = (1, 1280, 14, 14) + out_channels = 256 + kernel_size, strides, padding, dilation, groups = (1, 1), (1, 1), (0, 0, 0, 0), (1, 1), 1 + data_layout, kernel_layout = "NCHW", "OIHW" + dtype = "float32" + + f = get_conv2d_relu( + data_shape, + out_channels, + kernel_size, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + dtype, + ) + + mod, params = testing.create_workload(f) + verify_meta_schedule_with_tensorrt(mod, params, data_shape) + + +@has_tensorrt_codegen +@pytest.mark.parametrize( + "model_name", + ["resnet-50", "mobilenet"], +) +@pytest.mark.parametrize("batch_size", [1, 8]) +@pytest.mark.parametrize("use_meta_sched", [True]) +@pytest.mark.parametrize("use_trt", [True, False]) +def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool): + mod, params, input_shape, _oshape = get_network( + name=model_name, + batch_size=batch_size, + ) + verify_meta_schedule_with_tensorrt( + mod, + params, + input_shape, + use_meta_sched=use_meta_sched, + use_trt=use_trt, + mode="vm", + ) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py index 3f98d711ea61..cdc72d30b605 100644 --- a/tests/python/unittest/test_meta_schedule_cost_model.py +++ b/tests/python/unittest/test_meta_schedule_cost_model.py @@ -14,23 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=missing-docstring +from typing import List + +import tempfile import os import re -import shutil import sys -import tempfile -from typing import List - -import numpy as np +import shutil import pytest +import numpy as np + import tvm -from tvm.meta_schedule.cost_model import PyCostModel, RandomModel -from tvm.meta_schedule.runner import RunnerResult -from tvm.meta_schedule.search_strategy import MeasureCandidate -from tvm.meta_schedule.tune_context import TuneContext from tvm.script import tir as T from tvm.tir.schedule.schedule import Schedule +from tvm.meta_schedule.search_strategy import MeasureCandidate +from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.feature_extractor import RandomFeatureExtractor +from tvm.meta_schedule.cost_model import PyCostModel, RandomModel, XGBModel +from tvm.meta_schedule.tune_context import TuneContext # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @tvm.script.ir_module @@ -139,5 +140,81 @@ def test_meta_schedule_random_model_reload(): assert (res1 == res2).all() +def _dummy_candidate(): + return MeasureCandidate(Schedule(Matmul), []) + + +def _dummy_result(num_samples: int = 4, max_run_sec: int = 10): + return RunnerResult(list(np.random.rand(num_samples) * max_run_sec + 1e-6), None) + + +def test_meta_schedule_xgb_model(): + extractor = RandomFeatureExtractor() + model = XGBModel(extractor=extractor, num_warmup_samples=2) + update_sample_count = 10 + predict_sample_count = 100 + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + + +def test_meta_schedule_xgb_model_reload(): + extractor = RandomFeatureExtractor() + model = XGBModel(extractor=extractor, num_warmup_samples=10) + update_sample_count = 20 + predict_sample_count = 30 + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + random_state = model.extractor.random_state # save feature extractor's random state + path = os.path.join(tempfile.mkdtemp(), "test_output_meta_schedule_xgb_model.bin") + cached = (model.cached_features.copy(), model.cached_mean_costs.copy()) + model.save(path) + res1 = model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + model.extractor.random_state = random_state # load feature extractor's random state + model.cached_features = None + model.cached_mean_costs = None + model.load(path) + new_cached = (model.cached_features.copy(), model.cached_mean_costs.copy()) + res2 = model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + shutil.rmtree(os.path.dirname(path)) + assert (res1 == res2).all() + # cached feature does not change + assert len(cached[0]) == len(new_cached[0]) + for i in range(len(cached[0])): + assert (cached[0][i] == new_cached[0][i]).all() + # cached meaen cost does not change + assert (cached[1] == new_cached[1]).all() + + +def test_meta_schedule_xgb_model_reupdate(): + extractor = RandomFeatureExtractor() + model = XGBModel(extractor=extractor, num_warmup_samples=2) + update_sample_count = 60 + predict_sample_count = 100 + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.update( + TuneContext(), + [_dummy_candidate() for i in range(update_sample_count)], + [_dummy_result() for i in range(update_sample_count)], + ) + model.predict(TuneContext(), [_dummy_candidate() for i in range(predict_sample_count)]) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor.py b/tests/python/unittest/test_meta_schedule_feature_extractor.py index 143d446f48fd..4f068d7a8313 100644 --- a/tests/python/unittest/test_meta_schedule_feature_extractor.py +++ b/tests/python/unittest/test_meta_schedule_feature_extractor.py @@ -15,13 +15,14 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring -import re from typing import List +import re import numpy as np + from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.feature_extractor import PyFeatureExtractor from tvm.meta_schedule.search_strategy import MeasureCandidate +from tvm.meta_schedule.feature_extractor import PyFeatureExtractor def test_meta_schedule_feature_extractor(): diff --git a/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py new file mode 100644 index 000000000000..2db43caeacaa --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_feature_extractor_per_store_feature.py @@ -0,0 +1,1536 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import Callable, List + +from numpy.testing import assert_allclose +import tvm +from tvm import meta_schedule as ms, te, tir +from tvm.meta_schedule.testing import te_workload +from tvm.script import tir as T + +N_FEATURES = 164 + + +def _make_context(target) -> ms.TuneContext: + return ms.TuneContext( + target=target, + num_threads=1, + ) + + +def _make_candidate(f_sch: Callable[[], tir.Schedule]) -> ms.MeasureCandidate: + return ms.MeasureCandidate(sch=f_sch(), args_info=[]) + + +def _feature_names( # pylint: disable=invalid-name + buffers_per_store: int = 5, + arith_intensity_curve_num_samples: int = 10, +) -> List[str]: + result = [ + "float_mad", + "float_addsub", + "float_mul", + "float_divmod", + "float_cmp", + "float_mathfunc", + "float_otherfunc", + "int_mad", + "int_addsub", + "int_mul", + "int_divmod", + "int_cmp", + "int_mathfunc", + "int_otherfunc", + "bool_op", + "select_op", + "vec_num", + "vec_prod", + "vec_len", + "vec_type.kPosNone", + "vec_type.kPosInnerSpatial", + "vec_type.kPosMiddleSpatial", + "vec_type.kPosOuterSpatial", + "vec_type.kPosInnerReduce", + "vec_type.kPosMiddleReduce", + "vec_type.kPosOuterReduce", + "vec_type.kPosMixed", + "unroll_num", + "unroll_prod", + "unroll_len", + "unroll_type.kPosNone", + "unroll_type.kPosInnerSpatial", + "unroll_type.kPosMiddleSpatial", + "unroll_type.kPosOuterSpatial", + "unroll_type.kPosInnerReduce", + "unroll_type.kPosMiddleReduce", + "unroll_type.kPosOuterReduce", + "unroll_type.kPosMixed", + "parallel_num", + "parallel_prod", + "parallel_len", + "parallel_type.kPosNone", + "parallel_type.kPosInnerSpatial", + "parallel_type.kPosMiddleSpatial", + "parallel_type.kPosOuterSpatial", + "parallel_type.kPosInnerReduce", + "parallel_type.kPosMiddleReduce", + "parallel_type.kPosOuterReduce", + "parallel_type.kPosMixed", + "is_gpu", + "blockIdx_x_len", + "blockIdx_y_len", + "blockIdx_z_len", + "threadIdx_x_len", + "threadIdx_y_len", + "threadIdx_z_len", + "vthread_len", + ] + for i in range(buffers_per_store): + result.extend( + f"B{i}.{s}" + for s in [ + "acc_type.kRead", + "acc_type.kWrite", + "acc_type.kReadWrite", + "bytes", + "unique_bytes", + "lines", + "unique_lines", + "reuse_type.kLoopMultipleRead", + "reuse_type.kSerialMultipleReadWrite", + "reuse_type.kNoReuse", + "reuse_dis_iter", + "reuse_dis_bytes", + "reuse_ct", + "bytes_d_reuse_ct", + "unique_bytes_d_reuse_ct", + "lines_d_reuse_ct", + "unique_lines_d_reuse_ct", + "stride", + ] + ) + result.extend(f"arith_intensity_curve_{i}" for i in range(arith_intensity_curve_num_samples)) + result.extend( + [ + "alloc_size", + "alloc_prod", + "alloc_outer_prod", + "alloc_inner_prod", + "outer_prod", + "num_loops", + "auto_unroll_max_step", + ] + ) + # 57 + 18 * 5 + 10 + 4 + 3 + assert len(result) == N_FEATURES + return result + + +def _zip_feature(feature, names): + assert feature.ndim == 1 + assert feature.shape[0] == N_FEATURES + assert len(names) == N_FEATURES + return list(zip(names, feature)) + + +def _print_feature(feature, st, ed): # pylint: disable=invalid-name + named_feature = _zip_feature(feature, _feature_names()) + for k, v in named_feature[st:ed]: + print("\t", k, v) + + +def test_cpu_matmul(): + def _create_schedule(): + func = te.create_prim_func(te_workload.matmul(n=512, m=512, k=512)) + sch = tir.Schedule(func, debug_mask="all") + block = sch.get_block("C") + i, j, k = sch.get_loops(block) + i_o, i_i = sch.split(i, factors=[None, 16]) # outer: 32 + j_o, j_i = sch.split(j, factors=[None, 8]) # outer: 64 + sch.reorder(i_o, j_o, k, j_i, i_i) + sch.vectorize(j_i) + sch.parallel(i_o) + sch.parallel(j_o) + sch.unroll(k) + return sch + + extractor = ms.feature_extractor.PerStoreFeature() + (feature,) = extractor.extract_from( + _make_context(tvm.target.Target("llvm")), + candidates=[_make_candidate(_create_schedule)], + ) + feature = feature.numpy() + assert feature.shape == (1, N_FEATURES) + f = feature[0] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + # fmt: off + desired=[ + # float math ops + 0, 27, 27, 0, 0, 0, 0, + # int math ops + 0, 29, 29, 0, 0, 0, 0, + # bool/select ops + 0, 0, + ], + # fmt: on + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[1.0, 3.169924, 3.169924, 0, 0, 0, 0, 0, 0, 0, 1], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[1.0, 9.002815, 9.002815, 0, 0, 0, 0, 0, 0, 0, 1], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[1.58496, 11.0007, 6.022368, 0, 0, 0, 0, 0, 0, 0, 1], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer A + assert_allclose( + actual=f[57:75], + desired=[ + 1, + 0, + 0, + 29, + 20, + 27, + 14, + 1, + 0, + 0, + 4.087463, + 7.0552826, + 3.169925, + 26, + 17, + 24, + 11.0007038, + 9.002815, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer C + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 0.0, + 1.0, + 29.0, + 20.000001907348633, + 27.0, + 14.00008773803711, + 1.0, + 0.0, + 0.0, + 7.011227130889893, + 9.250298500061035, + 9.002815246582031, + 20.000001907348633, + 11.000703811645508, + 18.0000057220459, + 5.044394016265869, + 9.002815246582031, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Buffer B + assert_allclose( + actual=f[93:111], + desired=[ + 1.0, + 0.0, + 0.0, + 29.0, + 20.000001907348633, + 19.000001907348633, + 14.00008773803711, + 1.0, + 0.0, + 0.0, + 1.0, + 3.700439691543579, + 4.087462902069092, + 25.0, + 16.000022888183594, + 15.000043869018555, + 10.001408576965332, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[ + 0.7097842693328857, + 0.7408391237258911, + 0.8750449419021606, + 0.9449487924575806, + 1.0148526430130005, + 1.0847564935684204, + 1.113688349723816, + 1.1394684314727783, + 1.2119636535644531, + 1.2971993684768677, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 20.000001907348633, + 18.0000057220459, + 1.0, + 27.0, + 27.0, + 2.5849626064300537, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + + +def test_cpu_fusion(): + # pylint: disable=all + @T.prim_func + def func(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [64, 32], dtype="float32") + B = T.match_buffer(b, [64, 32], dtype="float32") + C = T.match_buffer(c, [64, 32], dtype="float32") + for i, j in T.grid(64, 32): # type: ignore + with T.block(): + T.reads([A[i, j], B[i, j]]) # type: ignore + T.writes([B[i, j], C[i, j]]) # type: ignore + with T.block("B"): + T.reads([A[i, j]]) # type: ignore + T.writes([B[i, j]]) # type: ignore + B[i, j] = A[i, j] # type: ignore + with T.block("C"): + T.reads([B[i, j]]) # type: ignore + T.writes([C[i, j]]) # type: ignore + C[i, j] = B[i, j] # type: ignore + + # pylint: enable=all + + def _create_schedule(): + return tir.Schedule(func, debug_mask="all") + + extractor = ms.feature_extractor.PerStoreFeature() + (feature,) = extractor.extract_from( + _make_context(tvm.target.Target("llvm")), + candidates=[_make_candidate(_create_schedule)], + ) + feature = feature.numpy() + assert feature.shape == (2, N_FEATURES) + ## Features for BufferStore(B) + f = feature[0] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + # fmt: off + desired=[0.0] * 16, + # fmt: on + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer A + assert_allclose( + actual=f[57:75], + desired=[ + 1.0, + 0.0, + 0.0, + 13.000176429748535, + 13.000176429748535, + 7.011227130889893, + 7.011227130889893, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 14.00008773803711, + 14.00008773803711, + 8.005624771118164, + 8.005624771118164, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer B + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 1.0, + 0.0, + 13.000176429748535, + 13.000176429748535, + 7.011227130889893, + 7.011227130889893, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 14.00008773803711, + 14.00008773803711, + 8.005624771118164, + 8.005624771118164, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Dummy padding + assert_allclose( + actual=f[93:111], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 13.000176, + 11.000703811645508, + 1.0, + 11.000703811645508, + 11.000703811645508, + 1.5849624872207642, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + ## Features for BufferStore(C) + f = feature[1] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + # fmt: off + desired=[0.0] * 16, + # fmt: on + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer B + assert_allclose( + actual=f[57:75], + desired=[ + 1.0, + 0.0, + 0.0, + 13.000176429748535, + 13.000176429748535, + 7.011227130889893, + 7.011227130889893, + 0.0, + 1.0, + 0.0, + 1.0, + 4.087462902069092, + 1.0, + 13.000176429748535, + 13.000176429748535, + 7.011227130889893, + 7.011227130889893, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer C + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 1.0, + 0.0, + 13.000176429748535, + 13.000176429748535, + 7.011227130889893, + 7.011227130889893, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 14.00008773803711, + 14.00008773803711, + 8.005624771118164, + 8.005624771118164, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Dummy padding + assert_allclose( + actual=f[93:111], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[0.0] * 18, + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 13.000176429748535, + 11.000703811645508, + 1.0, + 11.000703811645508, + 11.000703811645508, + 1.5849624872207642, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + + +def test_gpu(): + def _create_schedule(): + n = m = k = 512 + func = te.create_prim_func(te_workload.matmul(n=n, m=m, k=k)) + sch = tir.Schedule(func, debug_mask="all") + c = sch.get_block("C") + c_local = sch.cache_write(c, 0, "local") + i, j, k = sch.get_loops(c) + # pylint: disable=invalid-name + i0, i1, i2, i3, i4 = sch.split(i, factors=[None, 1, 16, 32, 1]) # outer: 1 + j0, j1, j2, j3, j4 = sch.split(j, factors=[None, 4, 1, 1, 16]) # outer: 8 + k0, k1, k2 = sch.split(k, factors=[None, 1, 2]) # outer: 256 + # pylint: enable=invalid-name + # fmt: off + sch.reorder( + i0, j0, # S + i1, j1, # S + i2, j2, # S + k0, # R + k1, # R + i3, j3, # S + k2, # R + i4, j4, # S + ) + # fmt: on + # thread binding + i0_j0 = sch.fuse(i0, j0) + i1_j1 = sch.fuse(i1, j1) + i2_j2 = sch.fuse(i2, j2) + sch.bind(i0_j0, "blockIdx.x") + sch.bind(i1_j1, "vthread.x") + sch.bind(i2_j2, "threadIdx.x") + # fusion + sch.reverse_compute_at(c_local, i2_j2) + # cache read 'A' + a_shared = sch.cache_read(c, 1, "shared") + sch.compute_at(a_shared, k0) + _, _, _, _, a_i, a_j = sch.get_loops(a_shared) + a_ij = sch.fuse(a_i, a_j) + _, a_j = sch.split(a_ij, factors=[None, 16]) # outer: 64 + sch.bind(a_j, "threadIdx.x") + # cache read 'B' + b_shared = sch.cache_read(c, 2, "shared") + sch.compute_at(b_shared, k0) + _, _, _, _, b_i, b_j = sch.get_loops(b_shared) + b_ij = sch.fuse(b_i, b_j) + _, b_j = sch.split(b_ij, factors=[None, 16]) # outer: 8 + sch.bind(b_j, "threadIdx.x") + # auto unroll + sch.annotate(i0_j0, "pragma_auto_unroll_max_step", tir.IntImm("int32", 1024)) + sch.annotate(i0_j0, "pragma_unroll_explicit", tir.IntImm("int32", 1)) + return sch + + extractor = ms.feature_extractor.PerStoreFeature() + (feature,) = extractor.extract_from( + _make_context(tvm.target.Target("cuda")), + candidates=[_make_candidate(_create_schedule)], + ) + feature = feature.numpy() + assert feature.shape == (4, N_FEATURES) + ### Check feature[0]: BufferStore(A_shared) <= A[...] + f = feature[0] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 24.000000085991324, + 24.000000085991324, + 24.000000085991324, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 1.0, 2.321928094887362], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer A + assert_allclose( + actual=f[57:75], + desired=[ + 1.0, + 0.0, + 0.0, + 25.000000042995662, + 20.000001375860553, + 23.00000017198264, + 14.000088052430122, + 1.0, + 0.0, + 0.0, + 18.00000550343433, + 20.00562591970089, + 2.321928094887362, + 23.00000017198264, + 18.00000550343433, + 21.000000687930438, + 12.0003521774803, + 12.0003521774803, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer A.shared + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 1.0, + 0.0, + 25.000000042995662, + 12.0003521774803, + 23.00000017198264, + 9.002815015607053, + 1.0, + 0.0, + 0.0, + 6.022367813028454, + 11.98049663618346, + 8.005624549193879, + 17.000011006847668, + 4.087462841250339, + 15.000044026886828, + 1.584962500721156, + 4.087462841250339, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Dummy padding + assert_allclose( + actual=f[93:111], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 12.0003521774803, + 27.000000010748916, + 17.000011006847668, + 6.022367813028454, + 23.00000017198264, + 2.584962500721156, + 10.001408, + ], + rtol=1e-5, + atol=1e-5, + ) + ### Check feature[1]: BufferStore(B_shared) <= B[...] + f = feature[1] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 22.00000034396526, + 22.00000034396526, + 21.000000687930438, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 1.0, 2.321928094887362], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer B + assert_allclose( + actual=f[57:75], + desired=[ + 1.0, + 0.0, + 0.0, + 22.00000034396526, + 20.000001375860553, + 20.000001375860553, + 14.000088052430122, + 1.0, + 0.0, + 0.0, + 15.000044026886828, + 20.17555076886471, + 2.321928094887362, + 20.000001375860553, + 18.00000550343433, + 18.00000550343433, + 12.0003521774803, + 4.087462841250339, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer B.shared + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 1.0, + 0.0, + 22.00000034396526, + 9.002815015607053, + 20.000001375860553, + 3.169925001442312, + 1.0, + 0.0, + 0.0, + 3.169925001442312, + 10.001408194392809, + 8.005624549193879, + 14.000088052430122, + 1.584962500721156, + 12.0003521774803, + 0.044394119358453436, + 4.087462841250339, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Dummy padding + assert_allclose( + actual=f[93:111], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 9.002815015607053, + 24.000000085991324, + 17.000011006847668, + 3.169925001442312, + 20.000001375860553, + 2.584962500721156, + 10.001408, + ], + rtol=1e-5, + atol=1e-5, + ) + ### Check feature[2]: BufferStore(C_local) <= C_local[...] + A_shared[...] * B_shared[...] + f = feature[2] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + desired=[ + 0.0, + 27.000000010748916, + 27.000000010748916, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 28.000000005374456, + 28.000000005374456, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 1.0, 2.321928094887362], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer B.shared + assert_allclose( + actual=f[57:75], + desired=[ + 1.0, + 0.0, + 0.0, + 29.00000000268723, + 9.002815015607053, + 23.00000017198264, + 3.169925001442312, + 1.0, + 0.0, + 0.0, + 5.044394119358453, + 7.651051691178929, + 5.044394119358453, + 24.000000085991324, + 4.087462841250339, + 18.00000550343433, + 0.32192809488736235, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer C.local + assert_allclose( + actual=f[75:93], + desired=[ + 0.0, + 0.0, + 1.0, + 29.00000000268723, + 11.000704269011246, + 23.00000017198264, + 5.044394119358453, + 1.0, + 0.0, + 0.0, + 4.087462841250339, + 7.05528243550119, + 1.584962500721156, + 28.000000005374456, + 10.001408194392809, + 22.00000034396526, + 4.087462841250339, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Buffer A.shared + assert_allclose( + actual=f[93:111], + desired=[ + 1.0, + 0.0, + 0.0, + 29.00000000268723, + 12.0003521774803, + 19.00000275171979, + 9.002815015607053, + 1.0, + 0.0, + 0.0, + 1.0, + 3.700439718141092, + 4.087462841250339, + 25.000000042995662, + 8.005624549193879, + 15.000044026886828, + 5.044394119358453, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[ + 0.7097842504665767, + 0.7548801745187567, + 0.8775907547541741, + 0.9957389916154509, + 1.2446737395193135, + 1.493608487423176, + 1.7093103019954263, + 1.8031580276850985, + 1.9841832691827785, + 2.204648076869754, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 11.000704269011246, + 18.00000550343433, + 9.002815015607053, + 18.00000550343433, + 27.000000010748916, + 3.0, + 10.001408, + ], + rtol=1e-5, + atol=1e-5, + ) + ### Check feature[3]: BufferStore(C) <= C_local[...] + f = feature[3] + # Group 1.1: arith + assert_allclose( + actual=f[0:16], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.2: vectorize + assert_allclose( + actual=f[16:27], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.3: unroll + assert_allclose( + actual=f[27:38], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.4: parallel + assert_allclose( + actual=f[38:49], + desired=[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 1.5: is_gpu, blockIdx.x/y/z, threadIdx.x/y/z, vthread + assert_allclose( + actual=f[49:57], + desired=[1.0, 3.169925001442312, 1.0, 1.0, 4.087462841250339, 1.0, 1.0, 2.321928094887362], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.1: Buffer C + assert_allclose( + actual=f[57:75], + desired=[ + 0.0, + 1.0, + 0.0, + 20.000001375860553, + 20.000001375860553, + 14.000088052430122, + 14.000088052430122, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 21.000000687930438, + 21.000000687930438, + 15.000044026886828, + 15.000044026886828, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.2: Buffer C.local + assert_allclose( + actual=f[75:93], + desired=[ + 1.0, + 0.0, + 0.0, + 20.000001375860553, + 11.000704269011246, + 14.000088052430122, + 5.044394119358453, + 1.0, + 0.0, + 0.0, + 9.002815015607053, + 12.0003521774803, + 4.087462841250339, + 16.00002201361136, + 7.011227255423254, + 10.001408194392809, + 1.584962500721156, + 1.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.3: Dummy padding + assert_allclose( + actual=f[93:111], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.4: Dummy padding + assert_allclose( + actual=f[111:129], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 2.5: Dummy padding + assert_allclose( + actual=f[129:147], + desired=[ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + rtol=1e-5, + atol=1e-5, + ) + # Group 3: Arithmetic intensity + assert_allclose( + actual=f[147:157], + desired=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + rtol=1e-5, + atol=1e-5, + ) + # Group 4 & 5 + assert_allclose( + actual=f[157:164], + desired=[ + 20.000001375860553, + 18.00000550343433, + 1.0, + 18.00000550343433, + 18.00000550343433, + 2.584962500721156, + 10.001408, + ], + rtol=1e-5, + atol=1e-5, + ) + + +if __name__ == "__main__": + test_cpu_matmul() + test_cpu_fusion() + test_gpu() diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index f508c7d252e1..0ace4d2bd02c 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -112,7 +112,7 @@ def test_meta_schedule_integration_extract_from_resnet(): layout="NHWC", dtype="float32", ) - extracted_tasks = ms.integration.extract_task(mod, target="llvm", params=params) + extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params) assert len(extracted_tasks) == 30 diff --git a/tests/python/unittest/test_meta_schedule_mutator.py b/tests/python/unittest/test_meta_schedule_mutator.py new file mode 100644 index 000000000000..b4d94dc9a8e3 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from typing import List, Optional + +import re + +import tvm +from tvm.ir.base import assert_structural_equal +from tvm.script import tir as T + +from tvm.meta_schedule.mutator import PyMutator +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.utils import _get_hex_address +from tvm.tir.schedule import Schedule, Trace + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def test_meta_schedule_mutator(): + class FancyMutator(PyMutator): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, trace: Trace) -> Optional[Trace]: + return Trace(trace.insts, {}) + + mutator = FancyMutator() + sch = Schedule(Matmul) + res = mutator.apply(sch.trace) + assert res is not None + new_sch = sch.copy() + res.apply_to_schedule(new_sch, remove_postproc=True) + assert_structural_equal(sch.mod, new_sch.mod) + + +def test_meta_schedule_mutator_as_string(): + class YetAnotherFancyMutator(PyMutator): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, trace: Trace) -> Optional[Trace]: + pass + + def __str__(self) -> str: + return f"YetAnotherFancyMutator({_get_hex_address(self.handle)})" + + mutator = YetAnotherFancyMutator() + pattern = re.compile(r"YetAnotherFancyMutator\(0x[a-f|0-9]*\)") + assert pattern.match(str(mutator)) + + +if __name__ == "__main__": + test_meta_schedule_mutator() + test_meta_schedule_mutator_as_string() diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_compute_location.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_compute_location.py new file mode 100644 index 000000000000..439a4c19dba1 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_compute_location.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List + +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.mutator import MutateComputeLocation, Mutator +from tvm.script import tir as T +from tvm.target import Target +from tvm.tir import Schedule + +# pylint: disable=invalid-name, no-member + + +@T.prim_func +def add(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [2048, 2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048, 2048], dtype="float32") + A_cached = T.alloc_buffer([2048, 2048, 2048], dtype="float32") + # body + for i, j, k in T.grid(2048, 2048, 2048): + with T.block("move"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + T.reads([A[vi, vj, vk]]) + T.writes([A_cached[vi, vj, vk]]) + A_cached[vi, vj, vk] = A[vi, vj, vk] + for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32): + with T.block("add"): + vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2) + vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + j2) + vk = T.axis.spatial(2048, k0 * 32 + k1) + T.reads([A_cached[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A_cached[vi, vj, vk] + T.float32(1) + + +# pylint: enable=invalid-name, no-member + + +def _sch(decision: int) -> Schedule: + sch = Schedule(add, debug_mask="all") + # pylint: disable=invalid-name + b0 = sch.get_block(name="move", func_name="main") + l1 = sch.sample_compute_location(block=b0, decision=decision) + sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True) + # pylint: enable=invalid-name + return sch + + +def _make_mutator(target: Target) -> Mutator: + mutator = MutateComputeLocation() + mutator.initialize_with_tune_context(TuneContext(mod=add, target=target)) + return mutator + + +def test_mutate_compute_location_add(): + mutator = _make_mutator( + target=Target("llvm --num-cores=16"), + ) + sch = _sch(decision=4) + results = set() + for _ in range(100): + trace = mutator.apply(sch.trace) + decision = trace.decisions[trace.insts[-2]] + assert not decision == 4 + results.add(decision) + assert len(results) == 9 + + +if __name__ == """__main__""": + test_mutate_compute_location_add() diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py new file mode 100644 index 000000000000..e263114ef60f --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_parallel.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List + +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.mutator import MutateParallel, Mutator +from tvm.script import tir as T +from tvm.target import Target +from tvm.tir import Schedule + +# pylint: disable=invalid-name, no-member + + +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [512, 512]) + B = T.match_buffer(b, [512, 512]) + C = T.match_buffer(c, [512, 512]) + for i, j, k in T.grid(512, 512, 512): # type: ignore + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # type: ignore + with T.init(): + C[vi, vj] = 0.0 # type: ignore + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +# pylint: enable=invalid-name, no-member + + +def _sch(decisions: List[List[int]], ann_val: int) -> Schedule: + sch = Schedule(matmul, debug_mask="all") + # pylint: disable=invalid-name + d0, d1, d2 = decisions + b0 = sch.get_block(name="C", func_name="main") + root = sch.get_block(name="root", func_name="main") + sch.get_consumers(block=b0) + b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8 = sch.sample_perfect_tile( + loop=l2, + n=4, + max_innermost_factor=64, + decision=d0, + ) + l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8]) + v13, v14, v15, v16 = sch.sample_perfect_tile( + loop=l3, + n=4, + max_innermost_factor=64, + decision=d1, + ) + l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16]) + v21, v22 = sch.sample_perfect_tile( + loop=l4, + n=2, + max_innermost_factor=64, + decision=d2, + ) + l23, l24 = sch.split(loop=l4, factors=[v21, v22]) + sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) + sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True) + sch.annotate(block_or_loop=root, ann_key="meta_schedule.parallel", ann_val=ann_val) + # pylint: enable=invalid-name + return sch + + +def _make_mutator(target: Target, max_jobs_per_core: int) -> Mutator: + mutator = MutateParallel(max_jobs_per_core) + mutator.initialize_with_tune_context(TuneContext(mod=matmul, target=target)) + return mutator + + +def test_mutate_parallel_matmul(): + mutator = _make_mutator( + target=Target("llvm --num-cores=16"), + max_jobs_per_core=256, + ) + sch = _sch( + decisions=[ + [4, 32, 4, 1], + [8, 4, 8, 2], + [512, 1], + ], + ann_val=64, + ) + results = set() + for _ in range(100): + trace = mutator.apply(sch.trace) + ann_val = int(trace.insts[-1].inputs[1]) + results.add(ann_val) + if len(results) == 3: + break + assert len(results) == 3 + assert results == {4, 32, 4096} + + +if __name__ == """__main__""": + test_mutate_parallel_matmul() diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py new file mode 100644 index 000000000000..c10e2a96bda2 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_tile_size.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import math +from typing import List + +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.mutator import MutateTileSize, Mutator +from tvm.script import tir as T +from tvm.target import Target +from tvm.tir import Schedule + +# pylint: disable=invalid-name, no-member + + +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [512, 512]) + B = T.match_buffer(b, [512, 512]) + C = T.match_buffer(c, [512, 512]) + for i, j, k in T.grid(512, 512, 512): # type: ignore + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # type: ignore + with T.init(): + C[vi, vj] = 0.0 # type: ignore + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +# pylint: enable=invalid-name, no-member + + +def _sch(decisions: List[List[int]]) -> Schedule: + sch = Schedule(matmul, debug_mask="all") + # pylint: disable=invalid-name + (d0,) = decisions + b0 = sch.get_block(name="C", func_name="main") + sch.get_consumers(block=b0) + b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8 = sch.sample_perfect_tile( + loop=l2, + n=4, + max_innermost_factor=64, + decision=d0, + ) + l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8]) + l17, l18, l19, l20 = sch.split(loop=l3, factors=[8, 4, 8, 2]) + l23, l24 = sch.split(loop=l4, factors=[512, 1]) + sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) + sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True) + # pylint: enable=invalid-name + return sch + + +def _make_mutator(target: Target) -> Mutator: + mutator = MutateTileSize() + mutator.initialize_with_tune_context(TuneContext(mod=matmul, target=target)) + return mutator + + +def test_mutate_tile_size_matmul(): + mutator = _make_mutator( + target=Target("llvm --num-cores=16"), + ) + results = {} + sch = _sch(decisions=[[4, 32, 4, 1]]) + for _ in range(100): + trace = mutator.apply(sch.trace) + assert trace.insts[4].kind.name == "SamplePerfectTile" + decision = trace.decisions[trace.insts[4]] + decision = [int(x) for x in decision] + results[str(decision)] = decision + assert math.prod(decision) == 512 + assert len(results) > 15 + + +if __name__ == """__main__""": + test_mutate_tile_size_matmul() diff --git a/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py b/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py new file mode 100644 index 000000000000..3f3fbcafc0db --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator_mutate_unroll.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List + +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.mutator import MutateUnroll, Mutator +from tvm.script import tir as T +from tvm.target import Target +from tvm.tir import Schedule + +# pylint: disable=invalid-name, no-member + + +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [512, 512]) + B = T.match_buffer(b, [512, 512]) + C = T.match_buffer(c, [512, 512]) + for i, j, k in T.grid(512, 512, 512): # type: ignore + with T.block("C"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) # type: ignore + with T.init(): + C[vi, vj] = 0.0 # type: ignore + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +# pylint: enable=invalid-name, no-member + + +def _sch(decisions: List[List[int]]) -> Schedule: + sch = Schedule(matmul, debug_mask="all") + # pylint: disable=invalid-name + d0, d1, d2 = decisions + b0 = sch.get_block(name="C", func_name="main") + root = sch.get_block(name="root", func_name="main") + sch.get_consumers(block=b0) + b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8 = sch.sample_perfect_tile( + loop=l2, + n=4, + max_innermost_factor=64, + decision=d0, + ) + l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8]) + v13, v14, v15, v16 = sch.sample_perfect_tile( + loop=l3, + n=4, + max_innermost_factor=64, + decision=d1, + ) + l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16]) + v21, v22 = sch.sample_perfect_tile( + loop=l4, + n=2, + max_innermost_factor=64, + decision=d2, + ) + l23, l24 = sch.split(loop=l4, factors=[v21, v22]) + sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20) + sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True) + v57 = sch.sample_categorical( + candidates=[0, 16, 64, 512], + probs=[0.25, 0.25, 0.25, 0.25], + decision=0, + ) + sch.annotate(block_or_loop=root, ann_key="meta_schedule.unroll_explicit", ann_val=v57) + # pylint: enable=invalid-name + return sch + + +def _make_mutator(target: Target) -> Mutator: + mutator = MutateUnroll() + mutator.initialize_with_tune_context(TuneContext(mod=matmul, target=target)) + return mutator + + +def test_mutate_unroll_matmul(): + mutator = _make_mutator(target=Target("llvm --num-cores=16")) + sch = _sch( + decisions=[ + [4, 32, 4, 1], + [8, 4, 8, 2], + [512, 1], + ], + ) + results = set() + for _ in range(100): + trace = mutator.apply(sch.trace) + decision = trace.decisions[trace.insts[-2]] + results.add(decision) + if len(results) == 3: + break + assert len(results) == 3 + assert results == {1, 2, 3} + + +if __name__ == """__main__""": + test_mutate_unroll_matmul() diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index b78e67817ebf..bf539403ca89 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -22,14 +22,14 @@ import pytest import tvm +from tvm.tir.schedule import BlockRV, Schedule from tvm.error import TVMError from tvm.meta_schedule import TuneContext from tvm.meta_schedule.schedule_rule import PyScheduleRule from tvm.meta_schedule.space_generator import PostOrderApply from tvm.script import tir as T from tvm.target import Target -from tvm.tir.schedule import BlockRV, Schedule - +from tvm import register_func # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, # fmt: off @@ -50,6 +50,42 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] +@tvm.script.ir_module +class MatmulCustomized: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + with T.block("root"): + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + T.block_attr({"schedule_rule": "tvm.meta_schedule.test.custom_search_space"}) + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.ir_module +class MatmulCustomizedNoneRule: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + with T.block("root"): + T.block_attr({"schedule_rule": "None"}) + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + T.block_attr({"schedule_rule": "None"}) + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + @tvm.script.ir_module class DuplicateMatmul: @T.prim_func @@ -102,7 +138,7 @@ def main(a: T.handle, d: T.handle) -> None: A = T.match_buffer(a, [1024, 1024], dtype="float32") D = T.match_buffer(d, [1024, 1024], dtype="float32") # body - # with tir.block("root") + # with T.block("root") B = T.alloc_buffer([1024, 1024], dtype="float32") for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16): with T.block("A"): @@ -120,6 +156,209 @@ def main(a: T.handle, d: T.handle) -> None: D[vi, vj] = (B[vi, vj] + T.float32(3)) * T.float32(5) +# with T.block("root"): + +# with T.block("A"): +# # template: meta_schedule.testing.some_rule +# ... +# with T.block("B"): +# # ReLU +# ... +# with T.block("C"): +# # bias_add +# ... + + + +@tvm.script.ir_module +class Conv2d_Winograd: + @T.prim_func + def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_winograd: T.handle) -> None: + # function attr dict + T.func_attr({"layout_free_placeholders": [var_placeholder_1]}) + placeholder = T.match_buffer(var_placeholder, [1, 14, 14, 128], elem_offset=0, align=128, offset_factor=1) + placeholder_1 = T.match_buffer(var_placeholder_1, [6, 6, 128, 128], elem_offset=0, align=128, offset_factor=1) + conv2d_winograd = T.match_buffer(var_conv2d_winograd, [1, 12, 12, 128], elem_offset=0, align=128, offset_factor=1) + # body + with T.block("root"): + T.block_attr({"schedule_rule": "tvm.meta_schedule.test.custom_search_space.winograd"}) + data_pad = T.alloc_buffer([1, 16, 16, 128], elem_offset=0, align=128, offset_factor=1) + input_tile = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1) + B = T.alloc_buffer([6, 6], elem_offset=0, align=128, offset_factor=1) + data_pack = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1) + bgemm = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1) + A = T.alloc_buffer([6, 4], elem_offset=0, align=128, offset_factor=1) + inverse = T.alloc_buffer([4, 4, 9, 128], elem_offset=0, align=128, offset_factor=1) + for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 16, 16, 128): + with T.block("data_pad"): + T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]]) + T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]]) + T.block_attr({ + "schedule_rule": "None", + }) + data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(((((0 <= i1_1) and (i1_1 < 14)) and (0 <= i2_1)) and (i2_1 < 14)), placeholder[i0_1, i1_1, i2_1, i3_1], T.float32(0), dtype="float32") + for eps, nu, p, ci in T.grid(6, 6, 9, 128): + with T.block("input_tile"): + T.reads([data_pad[T.floordiv(p, 9), ((T.floordiv(T.floormod(p, 9), 3)*4) + eps), ((T.floormod(p, 3)*4) + nu), ci]]) + T.writes([input_tile[eps, nu, p, ci]]) + T.block_attr({ + "schedule_rule": "None", + }) + input_tile[eps, nu, p, ci] = data_pad[T.floordiv(p, 9), ((T.floordiv(T.floormod(p, 9), 3)*4) + eps), ((T.floormod(p, 3)*4) + nu), ci] + for i, j in T.grid(6, 6): + with T.block("B"): + T.writes([B[i, j]]) + T.block_attr({ + "const_matrix" : True, + "schedule_rule": "None", + }) + B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6): + with T.block("data_pack"): + eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap("SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5]) + T.reads([data_pack[eps_1, nu_1, p_1, ci_1], input_tile[r_a, r_b, p_1, ci_1], B[T.min(r_a, r_b):(T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b))), T.min(eps_1, nu_1):(T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1)))]]) + T.writes([data_pack[eps_1, nu_1, p_1, ci_1]]) + T.block_attr({ + "auto_scheduler_simplify_const_tensor_indices":["eps", "nu", "r_a", "r_b"], + "schedule_rule": "None", + }) + with T.init(): + data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0) + data_pack[eps_1, nu_1, p_1, ci_1] = (data_pack[eps_1, nu_1, p_1, ci_1] + ((input_tile[r_a, r_b, p_1, ci_1]*B[r_a, eps_1])*B[r_b, nu_1])) + for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128): + with T.block("bgemm"): + eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1]) + T.reads([bgemm[eps_2, nu_2, p_2, co], data_pack[eps_2, nu_2, p_2, ci_2], placeholder_1[eps_2, nu_2, co, ci_2]]) + T.writes([bgemm[eps_2, nu_2, p_2, co]]) + T.block_attr({ + "schedule_rule": "None", + }) + with T.init(): + bgemm[eps_2, nu_2, p_2, co] = T.float32(0) + bgemm[eps_2, nu_2, p_2, co] = (bgemm[eps_2, nu_2, p_2, co] + (data_pack[eps_2, nu_2, p_2, ci_2]*placeholder_1[eps_2, nu_2, co, ci_2])) + for i_1, j_1 in T.grid(6, 4): + with T.block("A"): + T.writes([A[i_1, j_1]]) + T.block_attr({ + "const_matrix" : True, + "schedule_rule": "None", + }) + A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))) + for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6): + with T.block("inverse"): + vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap("SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1]) + T.reads([inverse[vh, vw, p_3, co_1], bgemm[r_a_1, r_b_1, p_3, co_1], A[T.min(r_a_1, r_b_1):(T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1))), T.min(vh, vw):(T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw)))]]) + T.writes([inverse[vh, vw, p_3, co_1]]) + T.block_attr({ + "schedule_rule": "None", + "auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], + }) + with T.init(): + inverse[vh, vw, p_3, co_1] = T.float32(0) + inverse[vh, vw, p_3, co_1] = (inverse[vh, vw, p_3, co_1] + ((bgemm[r_a_1, r_b_1, p_3, co_1]*A[r_a_1, vh])*A[r_b_1, vw])) + for n, h, w, co_2 in T.grid(1, 12, 12, 128): + with T.block("conv2d_winograd"): + T.reads([inverse[T.floormod(h, 4), T.floormod(w, 4), (((n*9) + (T.floordiv(h, 4)*3)) + T.floordiv(w, 4)), co_2]]) + T.writes([conv2d_winograd[n, h, w, co_2]]) + T.block_attr({ + "schedule_rule": "None" + }) + conv2d_winograd[n, h, w, co_2] = inverse[T.floormod(h, 4), T.floormod(w, 4), (((n*9) + (T.floordiv(h, 4)*3)) + T.floordiv(w, 4)), co_2] + +@tvm.script.ir_module +class Conv2d_Winograd_Cuda: + @T.prim_func + def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_winograd: T.handle) -> None: + # function attr dict + T.func_attr({"layout_free_placeholders": [var_placeholder_1]}) + placeholder = T.match_buffer(var_placeholder, [1, 14, 14, 128], elem_offset=0, align=128, offset_factor=1) + placeholder_1 = T.match_buffer(var_placeholder_1, [6, 6, 128, 128], elem_offset=0, align=128, offset_factor=1) + conv2d_winograd = T.match_buffer(var_conv2d_winograd, [1, 12, 12, 128], elem_offset=0, align=128, offset_factor=1) + # body + with T.block("root"): + data_pad = T.alloc_buffer([1, 16, 16, 128], elem_offset=0, align=128, offset_factor=1) + input_tile = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1) + B = T.alloc_buffer([6, 6], elem_offset=0, align=128, offset_factor=1) + data_pack = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1) + bgemm = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1) + A = T.alloc_buffer([6, 4], elem_offset=0, align=128, offset_factor=1) + inverse = T.alloc_buffer([4, 4, 9, 128], elem_offset=0, align=128, offset_factor=1) + for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 16, 16, 128): + with T.block("data_pad"): + T.block_attr({ + "schedule_rule": "None", + }) + T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]]) + T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]]) + data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(((((0 <= i1_1) and (i1_1 < 14)) and (0 <= i2_1)) and (i2_1 < 14)), placeholder[i0_1, i1_1, i2_1, i3_1], T.float32(0), dtype="float32") + for eps, nu, p, ci in T.grid(6, 6, 9, 128): + with T.block("input_tile"): + T.block_attr({ + "schedule_rule": "None", + }) + T.reads([data_pad[T.floordiv(p, 9), ((T.floordiv(T.floormod(p, 9), 3)*4) + eps), ((T.floormod(p, 3)*4) + nu), ci]]) + T.writes([input_tile[eps, nu, p, ci]]) + input_tile[eps, nu, p, ci] = data_pad[T.floordiv(p, 9), ((T.floordiv(T.floormod(p, 9), 3)*4) + eps), ((T.floormod(p, 3)*4) + nu), ci] + for i, j in T.grid(6, 6): + with T.block("B"): + T.writes([B[i, j]]) + T.block_attr({ + "const_matrix":True, + "schedule_rule": "None", + }) + B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6): + with T.block("data_pack"): + eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap("SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5]) + T.reads([data_pack[eps_1, nu_1, p_1, ci_1], input_tile[r_a, r_b, p_1, ci_1], B[T.min(r_a, r_b):(T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b))), T.min(eps_1, nu_1):(T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1)))]]) + T.writes([data_pack[eps_1, nu_1, p_1, ci_1]]) + T.block_attr({ + "auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"], + "schedule_rule": "None", + }) + with T.init(): + data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0) + data_pack[eps_1, nu_1, p_1, ci_1] = (data_pack[eps_1, nu_1, p_1, ci_1] + ((input_tile[r_a, r_b, p_1, ci_1]*B[r_a, eps_1])*B[r_b, nu_1])) + for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128): + with T.block("bgemm"): + T.block_attr({ + "schedule_rule": "None", + }) + eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1]) + T.reads([bgemm[eps_2, nu_2, p_2, co], data_pack[eps_2, nu_2, p_2, ci_2], placeholder_1[eps_2, nu_2, co, ci_2]]) + T.writes([bgemm[eps_2, nu_2, p_2, co]]) + with T.init(): + bgemm[eps_2, nu_2, p_2, co] = T.float32(0) + bgemm[eps_2, nu_2, p_2, co] = (bgemm[eps_2, nu_2, p_2, co] + (data_pack[eps_2, nu_2, p_2, ci_2]*placeholder_1[eps_2, nu_2, co, ci_2])) + for i_1, j_1 in T.grid(6, 4): + with T.block("A"): + T.writes([A[i_1, j_1]]) + T.block_attr({ + "const_matrix":True, + "schedule_rule": "None", + }) + A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))) + for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6): + with T.block("inverse"): + vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap("SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1]) + T.reads([inverse[vh, vw, p_3, co_1], bgemm[r_a_1, r_b_1, p_3, co_1], A[T.min(r_a_1, r_b_1):(T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1))), T.min(vh, vw):(T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw)))]]) + T.writes([inverse[vh, vw, p_3, co_1]]) + T.block_attr({ + "auto_scheduler_simplify_const_tensor_indices":["vh", "vw", "r_a", "r_b"], + "schedule_rule": "None", + }) + with T.init(): + inverse[vh, vw, p_3, co_1] = T.float32(0) + inverse[vh, vw, p_3, co_1] = (inverse[vh, vw, p_3, co_1] + ((bgemm[r_a_1, r_b_1, p_3, co_1]*A[r_a_1, vh])*A[r_b_1, vw])) + for n, h, w, co_2 in T.grid(1, 12, 12, 128): + with T.block("conv2d_winograd"): + T.block_attr({ + "schedule_rule": "None", + }) + T.reads([inverse[T.floormod(h, 4), T.floormod(w, 4), (((n*9) + (T.floordiv(h, 4)*3)) + T.floordiv(w, 4)), co_2]]) + T.writes([conv2d_winograd[n, h, w, co_2]]) + conv2d_winograd[n, h, w, co_2] = inverse[T.floormod(h, 4), T.floormod(w, 4), (((n*9) + (T.floordiv(h, 4)*3)) + T.floordiv(w, 4)), co_2] + # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument @@ -338,5 +577,437 @@ def correct_trace(a, b, c, d): ) +def test_meta_schedule_post_order_apply_custom_search_space(): + @register_func("tvm.meta_schedule.test.custom_search_space") + def custom_search_space_func(sch: Schedule, block: BlockRV): + raise ValueError("Customized search space triggered!") + + mod = MatmulCustomized + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Custom Search Space Task", + sch_rules=[], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + with pytest.raises(ValueError, match="Customized search space triggered!"): + _ = post_order_apply.generate_design_space(mod) + + +class DontCallThisRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + print(sch.get(block)) + raise RuntimeError("This schedule rule should not be called!") + + +def test_meta_schedule_post_order_apply_custom_search_space_none_rule(): + mod = MatmulCustomizedNoneRule + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Custom Search Space Task", + sch_rules=[DontCallThisRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + _ = post_order_apply.generate_design_space(mod) + + +@pytest.mark.xfail # for compute_at bug +def test_meta_schedule_post_order_apply_custom_search_space_winograd(): + @register_func("tvm.meta_schedule.test.custom_search_space.winograd") + def custom_search_space_winograd_func(sch: Schedule, block: BlockRV) -> List[Schedule]: + b1 = sch.get_block(name="A") + sch.compute_inline(block=b1) + b2 = sch.get_block(name="B") + sch.compute_inline(block=b2) + b3 = sch.get_block(name="inverse") + l4, l5, l6, l7, l8, l9 = sch.get_loops(block=b3) + sch.unroll(loop=l4) + sch.unroll(loop=l5) + sch.unroll(loop=l8) + sch.unroll(loop=l9) + v10, v11 = sch.sample_perfect_tile(n=2, loop=l6, max_innermost_factor=64, decision=[1, 9]) + l12, l13 = sch.split(loop=l6, factors=[v10, v11]) + v14, v15 = sch.sample_perfect_tile(n=2, loop=l7, max_innermost_factor=64, decision=[2, 64]) + l16, l17 = sch.split(loop=l7, factors=[v14, v15]) + sch.reorder(l12, l16, l13, l17, l4, l5, l8, l9) + b18 = sch.get_block(name="data_pack") + l19, l20, l21, l22, l23, l24 = sch.get_loops(block=b18) + sch.unroll(loop=l19) + sch.unroll(loop=l20) + sch.unroll(loop=l23) + sch.unroll(loop=l24) + v25, v26 = sch.sample_perfect_tile(n=2, loop=l21, max_innermost_factor=64, decision=[9, 1]) + l27, l28 = sch.split(loop=l21, factors=[v25, v26]) + v29, v30 = sch.sample_perfect_tile(n=2, loop=l22, max_innermost_factor=64, decision=[32, 4]) + l31, l32 = sch.split(loop=l22, factors=[v29, v30]) + sch.reorder(l27, l31, l28, l32, l19, l20, l23, l24) + b33 = sch.get_block(name="bgemm") + b34 = sch.cache_write(block=b33, write_buffer_index=0, storage_scope="global") + b33, b34 = b34, b33 + l35, l36, l37, l38, l39 = sch.get_loops(block=b34) + v40, v41, v42, v43 = sch.sample_perfect_tile( + n=4, loop=l35, max_innermost_factor=64, decision=[1, 2, 3, 1] + ) + l44, l45, l46, l47 = sch.split(loop=l35, factors=[v40, v41, v42, v43]) + v48, v49, v50, v51 = sch.sample_perfect_tile( + n=4, loop=l36, max_innermost_factor=64, decision=[1, 1, 1, 6] + ) + l52, l53, l54, l55 = sch.split(loop=l36, factors=[v48, v49, v50, v51]) + v56, v57, v58, v59 = sch.sample_perfect_tile( + n=4, loop=l37, max_innermost_factor=64, decision=[1, 1, 1, 9] + ) + l60, l61, l62, l63 = sch.split(loop=l37, factors=[v56, v57, v58, v59]) + v64, v65, v66, v67 = sch.sample_perfect_tile( + n=4, loop=l38, max_innermost_factor=64, decision=[2, 1, 16, 4] + ) + l68, l69, l70, l71 = sch.split(loop=l38, factors=[v64, v65, v66, v67]) + v72, v73 = sch.sample_perfect_tile(n=2, loop=l39, max_innermost_factor=64, decision=[16, 8]) + l74, l75 = sch.split(loop=l39, factors=[v72, v73]) + sch.reorder( + l44, l52, l60, l68, l45, l53, l61, l69, l74, l46, l54, l62, l70, l75, l47, l55, l63, l71 + ) + sch.reverse_compute_at(block=b33, loop=l69, preserve_unit_loops=True) + b76 = sch.get_block(name="root") + sch.annotate(block_or_loop=b76, ann_key="auto_parallel_extent", ann_val=64) + sch.annotate(block_or_loop=b76, ann_key="auto_vectorize_extent", ann_val=32) + v77 = sch.sample_categorical( + candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=1 + ) + sch.annotate(block_or_loop=b76, ann_key="auto_unroll_explicit", ann_val=v77) + + b78 = sch.get_block(name="input_tile") + (b79,) = sch.get_consumers(block=b78) + l80 = sch.sample_compute_location(block=b79, decision=4) + sch.compute_at(block=b78, loop=l80, preserve_unit_loops=True) + + b81 = sch.get_block(name="data_pad") + (b82,) = sch.get_consumers(block=b81) + l83 = sch.sample_compute_location(block=b82, decision=-2) + sch.compute_at(block=b81, loop=l83, preserve_unit_loops=True) + return [sch] + + mod = Conv2d_Winograd + + # Add annotation + sch = Schedule(mod) + sch.annotate( + sch.get_block("root"), + "schedule_rule", + "tvm.meta_schedule.test.custom_search_space.winograd", + ) + mod = sch.mod + context = TuneContext( + mod=mod, + target=Target("llvm --num-cores=16"), + task_name="Custom Search Space Task", + sch_rules=[DontCallThisRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 1 + (sch,) = schs + assert str(sch.trace) == "\n".join( + [ + 'b0 = sch.get_block(name="data_pad", func_name="main")', + 'b1 = sch.get_block(name="input_tile", func_name="main")', + 'b2 = sch.get_block(name="B", func_name="main")', + 'b3 = sch.get_block(name="data_pack", func_name="main")', + 'b4 = sch.get_block(name="bgemm", func_name="main")', + 'b5 = sch.get_block(name="A", func_name="main")', + 'b6 = sch.get_block(name="inverse", func_name="main")', + 'b7 = sch.get_block(name="conv2d_winograd", func_name="main")', + 'b8 = sch.get_block(name="root", func_name="main")', + 'b9 = sch.get_block(name="A", func_name="main")', + "sch.compute_inline(block=b9)", + 'b10 = sch.get_block(name="B", func_name="main")', + "sch.compute_inline(block=b10)", + 'b11 = sch.get_block(name="inverse", func_name="main")', + "l12, l13, l14, l15, l16, l17 = sch.get_loops(block=b11)", + "sch.unroll(loop=l12)", + "sch.unroll(loop=l13)", + "sch.unroll(loop=l16)", + "sch.unroll(loop=l17)", + "v18, v19 = sch.sample_perfect_tile(loop=l14, n=2, max_innermost_factor=64, decision=[1, 9])", + "l20, l21 = sch.split(loop=l14, factors=[v18, v19])", + "v22, v23 = sch.sample_perfect_tile(loop=l15, n=2, max_innermost_factor=64, decision=[2, 64])", + "l24, l25 = sch.split(loop=l15, factors=[v22, v23])", + "sch.reorder(l20, l24, l21, l25, l12, l13, l16, l17)", + 'b26 = sch.get_block(name="data_pack", func_name="main")', + "l27, l28, l29, l30, l31, l32 = sch.get_loops(block=b26)", + "sch.unroll(loop=l27)", + "sch.unroll(loop=l28)", + "sch.unroll(loop=l31)", + "sch.unroll(loop=l32)", + "v33, v34 = sch.sample_perfect_tile(loop=l29, n=2, max_innermost_factor=64, decision=[9, 1])", + "l35, l36 = sch.split(loop=l29, factors=[v33, v34])", + "v37, v38 = sch.sample_perfect_tile(loop=l30, n=2, max_innermost_factor=64, decision=[32, 4])", + "l39, l40 = sch.split(loop=l30, factors=[v37, v38])", + "sch.reorder(l35, l39, l36, l40, l27, l28, l31, l32)", + 'b41 = sch.get_block(name="bgemm", func_name="main")', + 'b42 = sch.cache_write(block=b41, write_buffer_index=0, storage_scope="global")', + "l43, l44, l45, l46, l47 = sch.get_loops(block=b41)", + "v48, v49, v50, v51 = sch.sample_perfect_tile(loop=l43, n=4, max_innermost_factor=64, decision=[1, 2, 3, 1])", + "l52, l53, l54, l55 = sch.split(loop=l43, factors=[v48, v49, v50, v51])", + "v56, v57, v58, v59 = sch.sample_perfect_tile(loop=l44, n=4, max_innermost_factor=64, decision=[1, 1, 1, 6])", + "l60, l61, l62, l63 = sch.split(loop=l44, factors=[v56, v57, v58, v59])", + "v64, v65, v66, v67 = sch.sample_perfect_tile(loop=l45, n=4, max_innermost_factor=64, decision=[1, 1, 1, 9])", + "l68, l69, l70, l71 = sch.split(loop=l45, factors=[v64, v65, v66, v67])", + "v72, v73, v74, v75 = sch.sample_perfect_tile(loop=l46, n=4, max_innermost_factor=64, decision=[2, 1, 16, 4])", + "l76, l77, l78, l79 = sch.split(loop=l46, factors=[v72, v73, v74, v75])", + "v80, v81 = sch.sample_perfect_tile(loop=l47, n=2, max_innermost_factor=64, decision=[16, 8])", + "l82, l83 = sch.split(loop=l47, factors=[v80, v81])", + "sch.reorder(l52, l60, l68, l76, l53, l61, l69, l77, l82, l54, l62, l70, l78, l83, l55, l63, l71, l79)", + "sch.reverse_compute_at(block=b42, loop=l77, preserve_unit_loops=True)", + 'b84 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b84, ann_key="auto_parallel_extent", ann_val=64)', + 'sch.annotate(block_or_loop=b84, ann_key="auto_vectorize_extent", ann_val=32)', + "v85 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=1)", + 'sch.annotate(block_or_loop=b84, ann_key="auto_unroll_explicit", ann_val=v85)', + 'b86 = sch.get_block(name="input_tile", func_name="main")', + "l87 = sch.sample_compute_location(block=b86, decision=-1)", + "sch.compute_at(block=b86, loop=l87, preserve_unit_loops=True)", + 'b88 = sch.get_block(name="data_pad", func_name="main")', + "l89 = sch.sample_compute_location(block=b88, decision=-1)", + "sch.compute_at(block=b88, loop=l89, preserve_unit_loops=True)", + ], + ) + + +@pytest.mark.xfail # for compute_at bug +def test_meta_schedule_post_order_apply_custom_search_space_winograd_cuda(): + @register_func("tvm.meta_schedule.test.custom_search_space.winograd.cuda") + def custom_search_space_winograd_func_cuda(sch: Schedule, block: BlockRV) -> List[Schedule]: + b1 = sch.get_block(name="inverse") + l2, l3, l4, l5, l6, l7 = sch.get_loops(block=b1) + sch.unroll(loop=l2) + sch.unroll(loop=l3) + sch.unroll(loop=l6) + sch.unroll(loop=l7) + v8, v9 = sch.sample_perfect_tile(n=2, loop=l4, max_innermost_factor=64, decision=[3, 3]) + l10, l11 = sch.split(loop=l4, factors=[v8, v9]) + v12, v13 = sch.sample_perfect_tile(n=2, loop=l5, max_innermost_factor=64, decision=[2, 64]) + l14, l15 = sch.split(loop=l5, factors=[v12, v13]) + sch.reorder(l10, l14, l11, l15, l2, l3, l6, l7) + b16 = sch.get_block(name="data_pack") + l17, l18, l19, l20, l21, l22 = sch.get_loops(block=b16) + sch.unroll(loop=l17) + sch.unroll(loop=l18) + sch.unroll(loop=l21) + sch.unroll(loop=l22) + v23, v24 = sch.sample_perfect_tile(n=2, loop=l19, max_innermost_factor=64, decision=[3, 3]) + l25, l26 = sch.split(loop=l19, factors=[v23, v24]) + v27, v28 = sch.sample_perfect_tile(n=2, loop=l20, max_innermost_factor=64, decision=[64, 2]) + l29, l30 = sch.split(loop=l20, factors=[v27, v28]) + sch.reorder(l25, l29, l26, l30, l17, l18, l21, l22) + b31 = sch.get_block(name="bgemm") + b32 = sch.cache_write(block=b31, write_buffer_index=0, storage_scope="local") + b31, b32 = b32, b31 + l33, l34, l35, l36, l37 = sch.get_loops(block=b32) + v38, v39, v40, v41, v42 = sch.sample_perfect_tile( + n=5, loop=l33, max_innermost_factor=64, decision=[1, 1, 1, 1, 6] + ) + l43, l44, l45, l46, l47 = sch.split(loop=l33, factors=[v38, v39, v40, v41, v42]) + v48, v49, v50, v51, v52 = sch.sample_perfect_tile( + n=5, loop=l34, max_innermost_factor=64, decision=[1, 1, 1, 3, 2] + ) + l53, l54, l55, l56, l57 = sch.split(loop=l34, factors=[v48, v49, v50, v51, v52]) + v58, v59, v60, v61, v62 = sch.sample_perfect_tile( + n=5, loop=l35, max_innermost_factor=64, decision=[3, 1, 1, 1, 3] + ) + l63, l64, l65, l66, l67 = sch.split(loop=l35, factors=[v58, v59, v60, v61, v62]) + v68, v69, v70, v71, v72 = sch.sample_perfect_tile( + n=5, loop=l36, max_innermost_factor=64, decision=[4, 2, 1, 4, 4] + ) + l73, l74, l75, l76, l77 = sch.split(loop=l36, factors=[v68, v69, v70, v71, v72]) + v78, v79, v80 = sch.sample_perfect_tile( + n=3, loop=l37, max_innermost_factor=64, decision=[32, 1, 4] + ) + l81, l82, l83 = sch.split(loop=l37, factors=[v78, v79, v80]) + sch.reorder( + l43, + l53, + l63, + l73, + l44, + l54, + l64, + l74, + l45, + l55, + l65, + l75, + l81, + l82, + l46, + l56, + l66, + l76, + l83, + l47, + l57, + l67, + l77, + ) + l84 = sch.fuse(l43, l53, l63, l73) + sch.bind(loop=l84, thread_axis="blockIdx.x") + l85 = sch.fuse(l44, l54, l64, l74) + sch.bind(loop=l85, thread_axis="vthread.x") + l86 = sch.fuse(l45, l55, l65, l75) + sch.bind(loop=l86, thread_axis="threadIdx.x") + b87 = sch.cache_read(block=b32, read_buffer_index=2, storage_scope="shared") + sch.compute_at(block=b87, loop=l81, preserve_unit_loops=True) + l88, l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b87) + l96 = sch.fuse(l92, l93, l94, l95) + v97, v98 = sch.sample_perfect_tile( + n=2, loop=l96, max_innermost_factor=4, decision=[1536, 3] + ) + l99, l100 = sch.split(loop=l96, factors=[v97, v98]) + sch.vectorize(loop=l100) + sch.annotate(block_or_loop=l99, ann_key="loop_type", ann_val="lazy_cooperative_fetch") + b101 = sch.cache_read(block=b32, read_buffer_index=1, storage_scope="shared") + sch.compute_at(block=b101, loop=l81, preserve_unit_loops=True) + l102, l103, l104, l105, l106, l107, l108, l109 = sch.get_loops(block=b101) + l110 = sch.fuse(l106, l107, l108, l109) + v111, v112 = sch.sample_perfect_tile( + n=2, loop=l110, max_innermost_factor=4, decision=[432, 1] + ) + l113, l114 = sch.split(loop=l110, factors=[v111, v112]) + sch.vectorize(loop=l114) + sch.annotate(block_or_loop=l113, ann_key="loop_type", ann_val="lazy_cooperative_fetch") + sch.reverse_compute_at(block=b31, loop=l86, preserve_unit_loops=True) + b115 = sch.get_block(name="input_tile") + (b116,) = sch.get_consumers(block=b115) + l117, l118, l119, l120, l121, l122, l123, l124 = sch.get_loops(block=b116) + sch.compute_at(block=b115, loop=l120, preserve_unit_loops=True) + sch.set_scope(block=b115, buffer_index=0, storage_scope="local") + b125 = sch.get_block(name="A") + sch.compute_inline(block=b125) + b126 = sch.get_block(name="B") + sch.compute_inline(block=b126) + b127 = sch.get_block(name="data_pad") + sch.compute_inline(block=b127) + b128 = sch.get_block(name="root") + v129 = sch.sample_categorical( + candidates=[0, 16, 64, 512, 1024], probs=[0.2, 0.2, 0.2, 0.2, 0.2], decision=0 + ) + sch.annotate(block_or_loop=b128, ann_key="auto_unroll_explicit", ann_val=v129) + return [sch] + + mod = Conv2d_Winograd_Cuda + + # Add annotation + sch = Schedule(mod) + sch.annotate( + sch.get_block("root"), + "schedule_rule", + "tvm.meta_schedule.test.custom_search_space.winograd.cuda", + ) + mod = sch.mod + context = TuneContext( + mod=mod, + target=Target("nvidia/geforce-rtx-3070"), + task_name="Custom Search Space Task", + sch_rules=[DontCallThisRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 1 + (sch,) = schs + assert str(sch.trace) == "\n".join( + [ + 'b0 = sch.get_block(name="data_pad", func_name="main")', + 'b1 = sch.get_block(name="input_tile", func_name="main")', + 'b2 = sch.get_block(name="B", func_name="main")', + 'b3 = sch.get_block(name="data_pack", func_name="main")', + 'b4 = sch.get_block(name="bgemm", func_name="main")', + 'b5 = sch.get_block(name="A", func_name="main")', + 'b6 = sch.get_block(name="inverse", func_name="main")', + 'b7 = sch.get_block(name="conv2d_winograd", func_name="main")', + 'b8 = sch.get_block(name="root", func_name="main")', + 'b9 = sch.get_block(name="inverse", func_name="main")', + "l10, l11, l12, l13, l14, l15 = sch.get_loops(block=b9)", + "sch.unroll(loop=l10)", + "sch.unroll(loop=l11)", + "sch.unroll(loop=l14)", + "sch.unroll(loop=l15)", + "v16, v17 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64, decision=[3, 3])", + "l18, l19 = sch.split(loop=l12, factors=[v16, v17])", + "v20, v21 = sch.sample_perfect_tile(loop=l13, n=2, max_innermost_factor=64, decision=[2, 64])", + "l22, l23 = sch.split(loop=l13, factors=[v20, v21])", + "sch.reorder(l18, l22, l19, l23, l10, l11, l14, l15)", + 'b24 = sch.get_block(name="data_pack", func_name="main")', + "l25, l26, l27, l28, l29, l30 = sch.get_loops(block=b24)", + "sch.unroll(loop=l25)", + "sch.unroll(loop=l26)", + "sch.unroll(loop=l29)", + "sch.unroll(loop=l30)", + "v31, v32 = sch.sample_perfect_tile(loop=l27, n=2, max_innermost_factor=64, decision=[3, 3])", + "l33, l34 = sch.split(loop=l27, factors=[v31, v32])", + "v35, v36 = sch.sample_perfect_tile(loop=l28, n=2, max_innermost_factor=64, decision=[64, 2])", + "l37, l38 = sch.split(loop=l28, factors=[v35, v36])", + "sch.reorder(l33, l37, l34, l38, l25, l26, l29, l30)", + 'b39 = sch.get_block(name="bgemm", func_name="main")', + 'b40 = sch.cache_write(block=b39, write_buffer_index=0, storage_scope="local")', + "l41, l42, l43, l44, l45 = sch.get_loops(block=b39)", + "v46, v47, v48, v49, v50 = sch.sample_perfect_tile(loop=l41, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 6])", + "l51, l52, l53, l54, l55 = sch.split(loop=l41, factors=[v46, v47, v48, v49, v50])", + "v56, v57, v58, v59, v60 = sch.sample_perfect_tile(loop=l42, n=5, max_innermost_factor=64, decision=[1, 1, 1, 3, 2])", + "l61, l62, l63, l64, l65 = sch.split(loop=l42, factors=[v56, v57, v58, v59, v60])", + "v66, v67, v68, v69, v70 = sch.sample_perfect_tile(loop=l43, n=5, max_innermost_factor=64, decision=[3, 1, 1, 1, 3])", + "l71, l72, l73, l74, l75 = sch.split(loop=l43, factors=[v66, v67, v68, v69, v70])", + "v76, v77, v78, v79, v80 = sch.sample_perfect_tile(loop=l44, n=5, max_innermost_factor=64, decision=[4, 2, 1, 4, 4])", + "l81, l82, l83, l84, l85 = sch.split(loop=l44, factors=[v76, v77, v78, v79, v80])", + "v86, v87, v88 = sch.sample_perfect_tile(loop=l45, n=3, max_innermost_factor=64, decision=[32, 1, 4])", + "l89, l90, l91 = sch.split(loop=l45, factors=[v86, v87, v88])", + "sch.reorder(l51, l61, l71, l81, l52, l62, l72, l82, l53, l63, l73, l83, l89, l90, l54, l64, l74, l84, l91, l55, l65, l75, l85)", + "l92 = sch.fuse(l51, l61, l71, l81)", + 'sch.bind(loop=l92, thread_axis="blockIdx.x")', + "l93 = sch.fuse(l52, l62, l72, l82)", + 'sch.bind(loop=l93, thread_axis="vthread.x")', + "l94 = sch.fuse(l53, l63, l73, l83)", + 'sch.bind(loop=l94, thread_axis="threadIdx.x")', + 'b95 = sch.cache_read(block=b39, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b95, loop=l89, preserve_unit_loops=True)", + "l96, l97, l98, l99, l100, l101, l102, l103 = sch.get_loops(block=b95)", + "l104 = sch.fuse(l100, l101, l102, l103)", + "v105, v106 = sch.sample_perfect_tile(loop=l104, n=2, max_innermost_factor=4, decision=[1536, 3])", + "l107, l108 = sch.split(loop=l104, factors=[v105, v106])", + "sch.vectorize(loop=l108)", + 'sch.annotate(block_or_loop=l107, ann_key="loop_type", ann_val="lazy_cooperative_fetch")', + 'b109 = sch.cache_read(block=b39, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b109, loop=l89, preserve_unit_loops=True)", + "l110, l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b109)", + "l118 = sch.fuse(l114, l115, l116, l117)", + "v119, v120 = sch.sample_perfect_tile(loop=l118, n=2, max_innermost_factor=4, decision=[432, 1])", + "l121, l122 = sch.split(loop=l118, factors=[v119, v120])", + "sch.vectorize(loop=l122)", + 'sch.annotate(block_or_loop=l121, ann_key="loop_type", ann_val="lazy_cooperative_fetch")', + "sch.reverse_compute_at(block=b40, loop=l94, preserve_unit_loops=True)", + 'b123 = sch.get_block(name="input_tile", func_name="main")', + "b124, = sch.get_consumers(block=b123)", + "l125, l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b124)", + "sch.compute_at(block=b123, loop=l128, preserve_unit_loops=True)", + 'sch.set_scope(block=b123, buffer_index=0, storage_scope="local")', + 'b133 = sch.get_block(name="A", func_name="main")', + "sch.compute_inline(block=b133)", + 'b134 = sch.get_block(name="B", func_name="main")', + "sch.compute_inline(block=b134)", + 'b135 = sch.get_block(name="data_pad", func_name="main")', + "sch.compute_inline(block=b135)", + 'b136 = sch.get_block(name="root", func_name="main")', + "v137 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=0)", + 'sch.annotate(block_or_loop=b136, ann_key="auto_unroll_explicit", ann_val=v137)', + ] + ) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_postproc.py b/tests/python/unittest/test_meta_schedule_postproc.py new file mode 100644 index 000000000000..6e17e7bac3f2 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import math +import re + +import tvm +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import PyPostproc +from tvm.meta_schedule.utils import _get_hex_address +from tvm.script import tir as T +from tvm.target.target import Target +from tvm.tir.schedule import Schedule + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable + + +def _check_correct(schedule: Schedule): + trace = schedule.trace + for inst in trace.decisions: + assert math.prod(trace.decisions[inst]) == 1024 + + +def schedule_matmul(sch: Schedule): + block = sch.get_block("matmul") + i, j, k = sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + + +def test_meta_schedule_postproc(): + class FancyPostproc(PyPostproc): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule) -> bool: + schedule_matmul(sch) + return True + + postproc = FancyPostproc() + mod = Matmul + sch = Schedule(mod) + assert postproc.apply(sch) + try: + tvm.ir.assert_structural_equal(sch.mod, mod) + raise Exception("The postprocessors did not change the schedule.") + except ValueError: + _check_correct(sch) + + +def test_meta_schedule_postproc_fail(): + class FailingPostproc(PyPostproc): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule) -> bool: + return False + + postproc = FailingPostproc() + sch = Schedule(Matmul) + assert not postproc.apply(sch) + + +def test_meta_schedule_postproc_as_string(): + class NotSoFancyPostproc(PyPostproc): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule) -> bool: + pass + + def __str__(self) -> str: + return f"NotSoFancyPostproc({_get_hex_address(self.handle)})" + + postproc = NotSoFancyPostproc() + pattern = re.compile(r"NotSoFancyPostproc\(0x[a-f|0-9]*\)") + assert pattern.match(str(postproc)) + + +if __name__ == "__main__": + test_meta_schedule_postproc() + test_meta_schedule_postproc_fail() + test_meta_schedule_postproc_as_string() diff --git a/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py b/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py new file mode 100644 index 000000000000..d27e3e61084f --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import DisallowDynamicLoop +from tvm.script import tir as T +from tvm.target import Target + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postprocs=[ + DisallowDynamicLoop(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.ir_module +class DynamicLoop: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j in T.grid(1024, 1024): + for k in T.serial(0, i): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def test_postproc_disallow_dynamic_loops(): + mod = Matmul + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert ctx.postprocs[0].apply(sch) + + +def test_postproc_disallow_dynamic_loops_fail(): + mod = DynamicLoop + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +if __name__ == "__main__": + test_postproc_disallow_dynamic_loops() + test_postproc_disallow_dynamic_loops_fail() diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py new file mode 100644 index 000000000000..41d662ab194e --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py @@ -0,0 +1,325 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import RewriteCooperativeFetch +from tvm.meta_schedule.testing import te_workload +from tvm.script import tir as T +from tvm.target import Target +from tvm.te import create_prim_func + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postprocs=[ + RewriteCooperativeFetch(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class AfterRewrite0: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(var_A, [512, 512], dtype="float32") + B = T.match_buffer(var_B, [512, 512], dtype="float32") + C = T.match_buffer(var_C, [512, 512], dtype="float32") + # body + # with T.block("root") + C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i2_0 in T.serial(0, 1): + for ax0_ax1_fused_0 in T.serial(0, 32768): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512) + v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 1024): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(0, 2): + with T.block("B_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":2}) + B_shared[v0, v1] = B[v0, v1] + for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2): + with T.block("C"): + i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4) + j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4) + k = T.axis.reduce(512, i2_1 * 32 + i2_2) + T.reads([C_local[i, j], A_shared[i, k], B_shared[k, j]]) + T.writes([C_local[i, j]]) + with T.init(): + C_local[i, j] = T.float32(0) + C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] + for ax0, ax1 in T.grid(32, 4): + with T.block("C_local"): + v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1) + T.reads([C_local[v0, v1]]) + T.writes([C[v0, v1]]) + C[v0, v1] = C_local[v0, v1] + + + +@tvm.script.ir_module +class AfterRewrite1: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(var_A, [512, 512], dtype="float16") + B = T.match_buffer(var_B, [512, 512], dtype="float16") + C = T.match_buffer(var_C, [512, 512], dtype="float32") + # body + with T.block("root"): + T.reads([]) + T.writes([]) + T.block_attr({"meta_schedule.tensor_core_enabled":"1"}) + C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + C_local_wmma_accumulator = T.alloc_buffer([512, 512], dtype="float32", scope="wmma.accumulator") + A_shared = T.alloc_buffer([512, 512], dtype="float16", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float16", scope="shared") + A_shared_wmma_matrix_a = T.alloc_buffer([512, 512], dtype="float16", scope="wmma.matrix_a") + B_shared_wmma_matrix_b = T.alloc_buffer([512, 512], dtype="float16", scope="wmma.matrix_b") + for i0_0_0_i1_0_0_fused in T.thread_binding(0, 1, thread="blockIdx.x"): + for i0_0_1_i1_0_1_fused in T.thread_binding(0, 4, thread="blockIdx.y"): + for i0_0_2_i1_0_2_fused in T.thread_binding(0, 8, thread="threadIdx.y"): + for i2_0_0 in T.serial(0, 4): + for ax0_ax1_fused_0 in T.serial(0, 128): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) // 128) + v1 = T.axis.spatial(512, i2_0_0 * 128 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) % 128) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 32): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(0, 32, thread="threadIdx.x"): + for ax0_ax1_fused_3 in T.vectorized(0, 4): + with T.block("B_shared"): + v0 = T.axis.spatial(512, i2_0_0 * 128 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 256) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 256) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":4}) + B_shared[v0, v1] = B[v0, v1] + for i2_0_1, i0_0_3, i1_0_3, i2_0_2 in T.grid(8, 1, 1, 1): + for ax0, ax1 in T.grid(256, 16): + with T.block("A_shared_wmma.matrix_a"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + ax0) + v1 = T.axis.spatial(512, i2_0_0 * 128 + i2_0_1 * 16 + ax1) + T.reads([A_shared[v0, v1]]) + T.writes([A_shared_wmma_matrix_a[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_a"}) + A_shared_wmma_matrix_a[v0, v1] = A_shared[v0, v1] + for ax0, ax1 in T.grid(16, 32): + with T.block("B_shared_wmma.matrix_b"): + v0 = T.axis.spatial(512, i2_0_0 * 128 + i2_0_1 * 16 + ax0) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + i0_0_2_i1_0_2_fused * 32 + ax1) + T.reads([B_shared[v0, v1]]) + T.writes([B_shared_wmma_matrix_b[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_b"}) + B_shared_wmma_matrix_b[v0, v1] = B_shared[v0, v1] + for i0_0_4, i1_0_4 in T.grid(16, 2): + with T.block("blockized_C"): + io = T.axis.spatial(32, i0_0_1_i1_0_1_fused // 2 * 16 + i0_0_4) + jo = T.axis.spatial(32, i0_0_1_i1_0_1_fused % 2 * 16 + i0_0_2_i1_0_2_fused * 2 + i1_0_4) + ko = T.axis.reduce(32, i2_0_0 * 8 + i2_0_1) + T.reads([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16], A_shared_wmma_matrix_a[io * 16 : io * 16 + 16, ko * 16 : ko * 16 + 16], B_shared_wmma_matrix_b[ko * 16 : ko * 16 + 16, jo * 16 : jo * 16 + 16]]) + T.writes([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_fill"}) + with T.init(): + for i0_1, i1_1 in T.grid(16, 16): + with T.block("C_init"): + i_init = T.axis.spatial(512, io * 16 + i0_1) + j_init = T.axis.spatial(512, jo * 16 + i1_1) + T.reads([]) + T.writes([C_local_wmma_accumulator[i_init, j_init]]) + C_local_wmma_accumulator[i_init, j_init] = T.float32(0) + for i0_1, i1_1, i2_1 in T.grid(16, 16, 16): + with T.block("C"): + i = T.axis.spatial(512, io * 16 + i0_1) + j = T.axis.spatial(512, jo * 16 + i1_1) + k = T.axis.reduce(512, ko * 16 + i2_1) + T.reads([C_local_wmma_accumulator[i, j], A_shared_wmma_matrix_a[i, k], B_shared_wmma_matrix_b[k, j]]) + T.writes([C_local_wmma_accumulator[i, j]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync"}) + C_local_wmma_accumulator[i, j] = C_local_wmma_accumulator[i, j] + T.cast(A_shared_wmma_matrix_a[i, k], "float32") * T.cast(B_shared_wmma_matrix_b[k, j], "float32") + for ax0, ax1 in T.grid(256, 32): + with T.block("C_local_wmma.accumulator"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + ax0) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + i0_0_2_i1_0_2_fused * 32 + ax1) + T.reads([C_local_wmma_accumulator[v0, v1]]) + T.writes([C_local[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_store"}) + C_local[v0, v1] = C_local_wmma_accumulator[v0, v1] + for ax0, ax1 in T.grid(256, 32): + with T.block("C_local"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + ax0) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + i0_0_2_i1_0_2_fused * 32 + ax1) + T.reads([C_local[v0, v1]]) + T.writes([C[v0, v1]]) + C[v0, v1] = C_local[v0, v1] + + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def test_rewrite_cooperative_fetch(): + mod = create_prim_func(te_workload.matmul(n=512, m=512, k=512)) + target = _target() + ctx = _create_context(mod, target) + + sch = tir.Schedule(mod, debug_mask="all") + # fmt: off + # pylint: disable=line-too-long,invalid-name + b0 = sch.get_block(name="C", func_name="main") + b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local") + l2, l3, l4 = sch.get_loops(block=b0) + v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 16, 1, 2, 16]) + l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9]) + v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[16, 1, 8, 2, 2]) + l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19]) + v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64, decision=[1, 16, 32]) + l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27]) + sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24) + l31 = sch.fuse(l10, l20) + sch.bind(loop=l31, thread_axis="blockIdx.x") + l32 = sch.fuse(l11, l21) + sch.bind(loop=l32, thread_axis="vthread.x") + l33 = sch.fuse(l12, l22) + sch.bind(loop=l33, thread_axis="threadIdx.x") + b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared") + sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True) + _, _, _, _, l39, l40 = sch.get_loops(block=b34) + l41 = sch.fuse(l39, l40) + _, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4, decision=[262144, 1]) + sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43) + b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared") + sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True) + _, _, _, _, l49, l50 = sch.get_loops(block=b44) + l51 = sch.fuse(l49, l50) + _, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4, decision=[8192, 2]) + sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53) + sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True) + # pylint: enable=line-too-long,invalid-name + # fmt: on + sch.enter_postproc() + assert ctx.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, AfterRewrite0) + + +def test_rewrite_cooperative_fetch_tensor_core(): + mod = create_prim_func(te_workload.matmul_fp16(n=512, m=512, k=512)) + target = _target() + ctx = _create_context(mod, target) + + sch = tir.Schedule(mod, debug_mask="all") + # fmt: off + # pylint: disable=line-too-long,invalid-name + b0 = sch.get_block(name="C", func_name="main") + l1, l2, l3 = sch.get_loops(block=b0) + _, l5 = sch.split(loop=l1, factors=[32, 16]) + _, l7 = sch.split(loop=l2, factors=[32, 16]) + _, l9 = sch.split(loop=l3, factors=[32, 16]) + _, _, l12, _, l14, _ = sch.get_loops(block=b0) + sch.reorder(l12, l14, l5, l7, l9) + b16 = sch.blockize(loop=l5) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync") + sch.annotate(block_or_loop=b16, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_fill") + b17 = sch.get_block(name="root", func_name="main") + sch.annotate(block_or_loop=b17, ann_key="meta_schedule.tensor_core_enabled", ann_val="1") + b18 = sch.cache_write(block=b16, write_buffer_index=0, storage_scope="local") + b19 = sch.cache_write(block=b16, write_buffer_index=0, storage_scope="wmma.accumulator") + sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store") + l20, l21, l22 = sch.get_loops(block=b16) + v23, v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l20, n=5, max_innermost_factor=64, decision=[1, 2, 1, 1, 16]) + l28, l29, l30, l31, l32 = sch.split(loop=l20, factors=[v23, v24, v25, v26, v27]) + v33, v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=64, decision=[1, 2, 8, 1, 2]) + l38, l39, l40, l41, l42 = sch.split(loop=l21, factors=[v33, v34, v35, v36, v37]) + v43, v44, v45 = sch.sample_perfect_tile(loop=l22, n=3, max_innermost_factor=64, decision=[4, 8, 1]) + l46, l47, l48 = sch.split(loop=l22, factors=[v43, v44, v45]) + sch.reorder(l28, l38, l29, l39, l30, l40, l46, l47, l31, l41, l48, l32, l42) + l49 = sch.fuse(l28, l38) + sch.bind(loop=l49, thread_axis="blockIdx.x") + l50 = sch.fuse(l29, l39) + sch.bind(loop=l50, thread_axis="blockIdx.y") + l51 = sch.fuse(l30, l40) + sch.bind(loop=l51, thread_axis="threadIdx.y") + b52 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="shared") + sch.compute_at(block=b52, loop=l46, preserve_unit_loops=True) + _, _, _, _, l57, l58 = sch.get_loops(block=b52) + l59 = sch.fuse(l57, l58) + _, v61 = sch.sample_perfect_tile(loop=l59, n=2, max_innermost_factor=4, decision=[32768, 1]) + sch.annotate(block_or_loop=b52, ann_key="meta_schedule.cooperative_fetch", ann_val=v61) + b62 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="shared") + sch.compute_at(block=b62, loop=l46, preserve_unit_loops=True) + _, _, _, _, l67, l68 = sch.get_loops(block=b62) + l69 = sch.fuse(l67, l68) + _, v71 = sch.sample_perfect_tile(loop=l69, n=2, max_innermost_factor=4, decision=[8192, 4]) + sch.annotate(block_or_loop=b62, ann_key="meta_schedule.cooperative_fetch", ann_val=v71) + b72 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="wmma.matrix_a") + b73 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="wmma.matrix_b") + sch.compute_at(block=b72, loop=l48, preserve_unit_loops=True) + sch.compute_at(block=b73, loop=l48, preserve_unit_loops=True) + sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_a") + sch.annotate(block_or_loop=b73, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_b") + sch.reverse_compute_at(block=b19, loop=l51, preserve_unit_loops=True) + sch.reverse_compute_at(block=b18, loop=l51, preserve_unit_loops=True) + # pylint: enable=line-too-long,invalid-name + # fmt: on + sch.enter_postproc() + assert ctx.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, AfterRewrite1) + + +if __name__ == "__main__": + test_rewrite_cooperative_fetch() + test_rewrite_cooperative_fetch_tensor_core() diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py new file mode 100644 index 000000000000..9734c5796042 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +from tvm.script import tir as T + +from tvm.meta_schedule.postproc import RewriteParallelVectorizeUnroll +from tvm.tir.schedule import Schedule + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant +# fmt: off + +@tvm.script.ir_module +class Move_PUV: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32") + B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32") + # body + with T.block("root"): + T.block_attr({"meta_schedule.parallel":128, "meta_schedule.vectorize":32}) + for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32): + with T.block("move"): + vi = T.axis.spatial(1024, i0 * 16 + i1 * 4 + i2) + vj = T.axis.spatial(1024, j0 * 32 + j1 * 8 + j2) + vk = T.axis.spatial(1024, k0 * 32 + k1) + T.where((i0 * 4 + i1) * 4 + i2 < 1024 and (j0 * 4 + j1) * 8 + j2 < 1024 and k0 * 32 + k1 < 1024) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] + + +@T.prim_func +def Move_PUV0(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32") + B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32") + # body + with T.block("root"): + for i0_j0_fused in T.parallel(0, 8192): + for i1, j1, k0, i2, j2 in T.grid(4, 4, 64, 4, 8): + for k1_fused in T.vectorized(0, 32): + with T.block("move"): + vi = T.axis.spatial(1024, i0_j0_fused // 64 * 16 + i1 * 4 + i2) + vj = T.axis.spatial(1024, i0_j0_fused % 64 * 32 + j1 * 8 + j2) + vk = T.axis.spatial(1024, k0 * 32 + k1_fused) + T.where( + (i0_j0_fused // 64 * 4 + i1) * 4 + i2 < 1024 + and (i0_j0_fused % 64 * 4 + j1) * 8 + j2 < 1024 + and k0 * 32 + k1_fused < 1024 + ) + T.reads([A[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable + + +def test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize(): + postproc = RewriteParallelVectorizeUnroll() + sch = Schedule(Move_PUV) + assert postproc.apply(sch) + tvm.ir.assert_structural_equal(sch.mod["main"], Move_PUV0) + + +if __name__ == "__main__": + test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize() diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py new file mode 100644 index 000000000000..c935b6193f27 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py @@ -0,0 +1,394 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import RewriteReductionBlock +from tvm.script import tir as T +from tvm.target import Target + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postprocs=[ + RewriteReductionBlock(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class Before0: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + A = T.match_buffer(var_A, [512, 512], dtype="float32") + B = T.match_buffer(var_B, [512, 512], dtype="float32") + C = T.match_buffer(var_C, [512, 512], dtype="float32") + C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i2_0 in T.serial(0, 1): + for ax0_ax1_fused_0 in T.serial(0, 32768): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512) + v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 1024): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(0, 2): + with T.block("B_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":2}) + B_shared[v0, v1] = B[v0, v1] + for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2): + with T.block("C"): + i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4) + j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4) + k = T.axis.reduce(512, i2_1 * 32 + i2_2) + T.reads([C_local[i, j], A_shared[i, k], B_shared[k, j]]) + T.writes([C_local[i, j]]) + with T.init(): + C_local[i, j] = T.float32(0) + C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] + for ax0, ax1 in T.grid(32, 4): + with T.block("C_local"): + v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1) + T.reads([C_local[v0, v1]]) + T.writes([C[v0, v1]]) + C[v0, v1] = C_local[v0, v1] + + +@tvm.script.ir_module +class After0: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + A = T.match_buffer(var_A, [512, 512], dtype="float32") + B = T.match_buffer(var_B, [512, 512], dtype="float32") + C = T.match_buffer(var_C, [512, 512], dtype="float32") + C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i2_0 in T.serial(0, 1): + for ax0_ax1_fused_0 in T.serial(0, 32768): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512) + v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 1024): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"): + for ax0_ax1_fused_2 in T.vectorized(0, 2): + with T.block("B_shared"): + v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":2}) + B_shared[v0, v1] = B[v0, v1] + for i0_3_init, i1_3_init, i0_4_init, i1_4_init in T.grid(2, 2, 16, 2): + with T.block("C_init"): + i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3_init * 16 + i0_4_init) + j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3_init * 2 + i1_4_init) + T.reads([]) + T.writes([C_local[i, j]]) + C_local[i, j] = T.float32(0) + for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2): + with T.block("C_update"): + i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4) + j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4) + k = T.axis.reduce(512, i2_1 * 32 + i2_2) + T.reads([C_local[i, j], A_shared[i, k], B_shared[k, j]]) + T.writes([C_local[i, j]]) + C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] + for ax0, ax1 in T.grid(32, 4): + with T.block("C_local"): + v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1) + T.reads([C_local[v0, v1]]) + T.writes([C[v0, v1]]) + C[v0, v1] = C_local[v0, v1] + + +@tvm.script.ir_module +class Before1: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(var_A, [512, 512], dtype="float16") + B = T.match_buffer(var_B, [512, 512], dtype="float16") + C = T.match_buffer(var_C, [512, 512], dtype="float32") + # body + with T.block("root"): + T.reads([]) + T.writes([]) + T.block_attr({"meta_schedule.tensor_core_enabled":"1"}) + C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + C_local_wmma_accumulator = T.alloc_buffer([512, 512], dtype="float32", scope="wmma.accumulator") + A_shared = T.alloc_buffer([512, 512], dtype="float16", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float16", scope="shared") + A_shared_wmma_matrix_a = T.alloc_buffer([512, 512], dtype="float16", scope="wmma.matrix_a") + B_shared_wmma_matrix_b = T.alloc_buffer([512, 512], dtype="float16", scope="wmma.matrix_b") + for i0_0_0_i1_0_0_fused in T.thread_binding(0, 1, thread="blockIdx.x"): + for i0_0_1_i1_0_1_fused in T.thread_binding(0, 4, thread="blockIdx.y"): + for i0_0_2_i1_0_2_fused in T.thread_binding(0, 8, thread="threadIdx.y"): + for i2_0_0 in T.serial(0, 4): + for ax0_ax1_fused_0 in T.serial(0, 128): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) // 128) + v1 = T.axis.spatial(512, i2_0_0 * 128 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) % 128) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 32): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(0, 32, thread="threadIdx.x"): + for ax0_ax1_fused_3 in T.vectorized(0, 4): + with T.block("B_shared"): + v0 = T.axis.spatial(512, i2_0_0 * 128 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 256) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 256) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":4}) + B_shared[v0, v1] = B[v0, v1] + for i2_0_1, i0_0_3, i1_0_3, i2_0_2 in T.grid(8, 1, 1, 1): + for ax0, ax1 in T.grid(256, 16): + with T.block("A_shared_wmma.matrix_a"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + ax0) + v1 = T.axis.spatial(512, i2_0_0 * 128 + i2_0_1 * 16 + ax1) + T.reads([A_shared[v0, v1]]) + T.writes([A_shared_wmma_matrix_a[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_a"}) + A_shared_wmma_matrix_a[v0, v1] = A_shared[v0, v1] + for ax0, ax1 in T.grid(16, 32): + with T.block("B_shared_wmma.matrix_b"): + v0 = T.axis.spatial(512, i2_0_0 * 128 + i2_0_1 * 16 + ax0) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + i0_0_2_i1_0_2_fused * 32 + ax1) + T.reads([B_shared[v0, v1]]) + T.writes([B_shared_wmma_matrix_b[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_b"}) + B_shared_wmma_matrix_b[v0, v1] = B_shared[v0, v1] + for i0_0_4, i1_0_4 in T.grid(16, 2): + with T.block("blockized_C"): + io = T.axis.spatial(32, i0_0_1_i1_0_1_fused // 2 * 16 + i0_0_4) + jo = T.axis.spatial(32, i0_0_1_i1_0_1_fused % 2 * 16 + i0_0_2_i1_0_2_fused * 2 + i1_0_4) + ko = T.axis.reduce(32, i2_0_0 * 8 + i2_0_1) + T.reads([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16], A_shared_wmma_matrix_a[io * 16 : io * 16 + 16, ko * 16 : ko * 16 + 16], B_shared_wmma_matrix_b[ko * 16 : ko * 16 + 16, jo * 16 : jo * 16 + 16]]) + T.writes([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_fill"}) + with T.init(): + for i0_1, i1_1 in T.grid(16, 16): + with T.block("C_init"): + i_init = T.axis.spatial(512, io * 16 + i0_1) + j_init = T.axis.spatial(512, jo * 16 + i1_1) + T.reads([]) + T.writes([C_local_wmma_accumulator[i_init, j_init]]) + C_local_wmma_accumulator[i_init, j_init] = T.float32(0) + for i0_1, i1_1, i2_1 in T.grid(16, 16, 16): + with T.block("C"): + i = T.axis.spatial(512, io * 16 + i0_1) + j = T.axis.spatial(512, jo * 16 + i1_1) + k = T.axis.reduce(512, ko * 16 + i2_1) + T.reads([C_local_wmma_accumulator[i, j], A_shared_wmma_matrix_a[i, k], B_shared_wmma_matrix_b[k, j]]) + T.writes([C_local_wmma_accumulator[i, j]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync"}) + C_local_wmma_accumulator[i, j] = C_local_wmma_accumulator[i, j] + T.cast(A_shared_wmma_matrix_a[i, k], "float32") * T.cast(B_shared_wmma_matrix_b[k, j], "float32") + for ax0, ax1 in T.grid(256, 32): + with T.block("C_local_wmma.accumulator"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + ax0) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + i0_0_2_i1_0_2_fused * 32 + ax1) + T.reads([C_local_wmma_accumulator[v0, v1]]) + T.writes([C_local[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_store"}) + C_local[v0, v1] = C_local_wmma_accumulator[v0, v1] + for ax0, ax1 in T.grid(256, 32): + with T.block("C_local"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + ax0) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + i0_0_2_i1_0_2_fused * 32 + ax1) + T.reads([C_local[v0, v1]]) + T.writes([C[v0, v1]]) + C[v0, v1] = C_local[v0, v1] + + +@tvm.script.ir_module +class After1: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A = T.match_buffer(var_A, [512, 512], dtype="float16") + B = T.match_buffer(var_B, [512, 512], dtype="float16") + C = T.match_buffer(var_C, [512, 512], dtype="float32") + # body + with T.block("root"): + T.reads([]) + T.writes([]) + T.block_attr({"meta_schedule.tensor_core_enabled":"1"}) + C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + C_local_wmma_accumulator = T.alloc_buffer([512, 512], dtype="float32", scope="wmma.accumulator") + A_shared = T.alloc_buffer([512, 512], dtype="float16", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float16", scope="shared") + A_shared_wmma_matrix_a = T.alloc_buffer([512, 512], dtype="float16", scope="wmma.matrix_a") + B_shared_wmma_matrix_b = T.alloc_buffer([512, 512], dtype="float16", scope="wmma.matrix_b") + for i0_0_0_i1_0_0_fused in T.thread_binding(0, 1, thread="blockIdx.x"): + for i0_0_1_i1_0_1_fused in T.thread_binding(0, 4, thread="blockIdx.y"): + for i0_0_2_i1_0_2_fused in T.thread_binding(0, 8, thread="threadIdx.y"): + for i0_0_4_init, i1_0_4_init in T.grid(16, 2): + with T.block("blockized_C_init"): + io = T.axis.spatial(32, i0_0_1_i1_0_1_fused // 2 * 16 + i0_0_4_init) + jo = T.axis.spatial(32, i0_0_1_i1_0_1_fused % 2 * 16 + i0_0_2_i1_0_2_fused * 2 + i1_0_4_init) + T.reads([]) + T.writes([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16]]) + for i0_1, i1_1 in T.grid(16, 16): + with T.block("C_init"): + i_init = T.axis.spatial(512, io * 16 + i0_1) + j_init = T.axis.spatial(512, jo * 16 + i1_1) + T.reads([]) + T.writes([C_local_wmma_accumulator[i_init, j_init]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_fill"}) + C_local_wmma_accumulator[i_init, j_init] = T.float32(0) + for i2_0_0 in T.serial(0, 4): + for ax0_ax1_fused_0 in T.serial(0, 128): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) // 128) + v1 = T.axis.spatial(512, i2_0_0 * 128 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) % 128) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 32): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(0, 32, thread="threadIdx.x"): + for ax0_ax1_fused_3 in T.vectorized(0, 4): + with T.block("B_shared"): + v0 = T.axis.spatial(512, i2_0_0 * 128 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 256) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 256) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":4}) + B_shared[v0, v1] = B[v0, v1] + for i2_0_1, i0_0_3, i1_0_3, i2_0_2 in T.grid(8, 1, 1, 1): + for ax0, ax1 in T.grid(256, 16): + with T.block("A_shared_wmma.matrix_a"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + ax0) + v1 = T.axis.spatial(512, i2_0_0 * 128 + i2_0_1 * 16 + ax1) + T.reads([A_shared[v0, v1]]) + T.writes([A_shared_wmma_matrix_a[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_a"}) + A_shared_wmma_matrix_a[v0, v1] = A_shared[v0, v1] + for ax0, ax1 in T.grid(16, 32): + with T.block("B_shared_wmma.matrix_b"): + v0 = T.axis.spatial(512, i2_0_0 * 128 + i2_0_1 * 16 + ax0) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + i0_0_2_i1_0_2_fused * 32 + ax1) + T.reads([B_shared[v0, v1]]) + T.writes([B_shared_wmma_matrix_b[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_b"}) + B_shared_wmma_matrix_b[v0, v1] = B_shared[v0, v1] + for i0_0_4, i1_0_4 in T.grid(16, 2): + with T.block("blockized_C_update"): + io = T.axis.spatial(32, i0_0_1_i1_0_1_fused // 2 * 16 + i0_0_4) + jo = T.axis.spatial(32, i0_0_1_i1_0_1_fused % 2 * 16 + i0_0_2_i1_0_2_fused * 2 + i1_0_4) + ko = T.axis.reduce(32, i2_0_0 * 8 + i2_0_1) + T.reads([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16], A_shared_wmma_matrix_a[io * 16 : io * 16 + 16, ko * 16 : ko * 16 + 16], B_shared_wmma_matrix_b[ko * 16 : ko * 16 + 16, jo * 16 : jo * 16 + 16]]) + T.writes([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16]]) + for i0_1, i1_1, i2_1 in T.grid(16, 16, 16): + with T.block("C"): + i = T.axis.spatial(512, io * 16 + i0_1) + j = T.axis.spatial(512, jo * 16 + i1_1) + k = T.axis.reduce(512, ko * 16 + i2_1) + T.reads([C_local_wmma_accumulator[i, j], A_shared_wmma_matrix_a[i, k], B_shared_wmma_matrix_b[k, j]]) + T.writes([C_local_wmma_accumulator[i, j]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync"}) + C_local_wmma_accumulator[i, j] = C_local_wmma_accumulator[i, j] + T.cast(A_shared_wmma_matrix_a[i, k], "float32") * T.cast(B_shared_wmma_matrix_b[k, j], "float32") + for ax0, ax1 in T.grid(256, 32): + with T.block("C_local_wmma.accumulator"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + ax0) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + i0_0_2_i1_0_2_fused * 32 + ax1) + T.reads([C_local_wmma_accumulator[v0, v1]]) + T.writes([C_local[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_store"}) + C_local[v0, v1] = C_local_wmma_accumulator[v0, v1] + for ax0, ax1 in T.grid(256, 32): + with T.block("C_local"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + ax0) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + i0_0_2_i1_0_2_fused * 32 + ax1) + T.reads([C_local[v0, v1]]) + T.writes([C[v0, v1]]) + C[v0, v1] = C_local[v0, v1] + + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def test_rewrite_reduction_block(): + mod = Before0 + target = _target() + ctx = _create_context(mod, target) + sch = tir.Schedule(mod, debug_mask="all") + sch.enter_postproc() + assert ctx.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, After0) + + +def test_rewrite_reduction_block_tensor_core(): + mod = Before1 + target = _target() + ctx = _create_context(mod, target) + sch = tir.Schedule(mod, debug_mask="all") + sch.enter_postproc() + assert ctx.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, After1) + + +if __name__ == "__main__": + test_rewrite_reduction_block() + test_rewrite_reduction_block_tensor_core() diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensor_core.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensor_core.py new file mode 100644 index 000000000000..c11890aefa80 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_tensor_core.py @@ -0,0 +1,275 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import RewriteTensorCore +from tvm.script import tir as T +from tvm.target import Target +from tvm.meta_schedule.testing import tir_tensor_intrin + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postprocs=[ + RewriteTensorCore(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + + +@tvm.script.ir_module +class Before0: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + A = T.match_buffer(var_A, [512, 512], dtype="float16") + B = T.match_buffer(var_B, [512, 512], dtype="float16") + C_local = T.match_buffer(var_C, [512, 512], dtype="float32") + # body + with T.block("root"): + T.reads([]) + T.writes([]) + T.block_attr({"meta_schedule.tensor_core_enabled":"1"}) + # C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + C_local_wmma_accumulator = T.alloc_buffer([512, 512], dtype="float32", scope="wmma.accumulator") + A_shared = T.alloc_buffer([512, 512], dtype="float16", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float16", scope="shared") + A_shared_wmma_matrix_a = T.alloc_buffer([512, 512], dtype="float16", scope="wmma.matrix_a") + B_shared_wmma_matrix_b = T.alloc_buffer([512, 512], dtype="float16", scope="wmma.matrix_b") + for i0_0_0_i1_0_0_fused in T.thread_binding(0, 1, thread="blockIdx.x"): + for i0_0_1_i1_0_1_fused in T.thread_binding(0, 4, thread="blockIdx.y"): + for i0_0_2_i1_0_2_fused in T.thread_binding(0, 8, thread="threadIdx.y"): + for i0_0_4_init, i1_0_4_init in T.grid(16, 2): + with T.block("blockized_C_init"): + io = T.axis.spatial(32, i0_0_1_i1_0_1_fused // 2 * 16 + i0_0_4_init) + jo = T.axis.spatial(32, i0_0_1_i1_0_1_fused % 2 * 16 + i0_0_2_i1_0_2_fused * 2 + i1_0_4_init) + T.reads([]) + T.writes([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16]]) + for i0_1, i1_1 in T.grid(16, 16): + with T.block("C_init"): + i_init = T.axis.spatial(512, io * 16 + i0_1) + j_init = T.axis.spatial(512, jo * 16 + i1_1) + T.reads([]) + T.writes([C_local_wmma_accumulator[i_init, j_init]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_fill"}) + C_local_wmma_accumulator[i_init, j_init] = T.float32(0) + for i2_0_0 in T.serial(0, 4): + for ax0_ax1_fused_0 in T.serial(0, 128): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) // 128) + v1 = T.axis.spatial(512, i2_0_0 * 128 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) % 128) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 32): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(0, 32, thread="threadIdx.x"): + for ax0_ax1_fused_3 in T.vectorized(0, 4): + with T.block("B_shared"): + v0 = T.axis.spatial(512, i2_0_0 * 128 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 256) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 256) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":4}) + B_shared[v0, v1] = B[v0, v1] + for i2_0_1, i0_0_3, i1_0_3, i2_0_2 in T.grid(8, 1, 1, 1): + for ax0, ax1 in T.grid(256, 16): + with T.block("A_shared_wmma.matrix_a"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + ax0) + v1 = T.axis.spatial(512, i2_0_0 * 128 + i2_0_1 * 16 + ax1) + T.reads([A_shared[v0, v1]]) + T.writes([A_shared_wmma_matrix_a[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_a"}) + A_shared_wmma_matrix_a[v0, v1] = A_shared[v0, v1] + for ax0, ax1 in T.grid(16, 32): + with T.block("B_shared_wmma.matrix_b"): + v0 = T.axis.spatial(512, i2_0_0 * 128 + i2_0_1 * 16 + ax0) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + i0_0_2_i1_0_2_fused * 32 + ax1) + T.reads([B_shared[v0, v1]]) + T.writes([B_shared_wmma_matrix_b[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_b"}) + B_shared_wmma_matrix_b[v0, v1] = B_shared[v0, v1] + for i0_0_4, i1_0_4 in T.grid(16, 2): + with T.block("blockized_C_update"): + io = T.axis.spatial(32, i0_0_1_i1_0_1_fused // 2 * 16 + i0_0_4) + jo = T.axis.spatial(32, i0_0_1_i1_0_1_fused % 2 * 16 + i0_0_2_i1_0_2_fused * 2 + i1_0_4) + ko = T.axis.reduce(32, i2_0_0 * 8 + i2_0_1) + T.reads([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16], A_shared_wmma_matrix_a[io * 16 : io * 16 + 16, ko * 16 : ko * 16 + 16], B_shared_wmma_matrix_b[ko * 16 : ko * 16 + 16, jo * 16 : jo * 16 + 16]]) + T.writes([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16]]) + for i0_1, i1_1, i2_1 in T.grid(16, 16, 16): + with T.block("C"): + i = T.axis.spatial(512, io * 16 + i0_1) + j = T.axis.spatial(512, jo * 16 + i1_1) + k = T.axis.reduce(512, ko * 16 + i2_1) + T.reads([C_local_wmma_accumulator[i, j], A_shared_wmma_matrix_a[i, k], B_shared_wmma_matrix_b[k, j]]) + T.writes([C_local_wmma_accumulator[i, j]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync"}) + C_local_wmma_accumulator[i, j] = C_local_wmma_accumulator[i, j] + T.cast(A_shared_wmma_matrix_a[i, k], "float32") * T.cast(B_shared_wmma_matrix_b[k, j], "float32") + for ax0, ax1 in T.grid(256, 32): + with T.block("C_local_wmma.accumulator"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + ax0) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + i0_0_2_i1_0_2_fused * 32 + ax1) + T.reads([C_local_wmma_accumulator[v0, v1]]) + T.writes([C_local[v0, v1]]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_store"}) + C_local[v0, v1] = C_local_wmma_accumulator[v0, v1] + + +@tvm.script.ir_module +class After0: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: + s0 = T.var("int32") + s0_1 = T.var("int32") + s0_2 = T.var("int32") + s1 = T.var("int32") + s1_1 = T.var("int32") + s1_2 = T.var("int32") + A = T.match_buffer(var_A, [512, 512], dtype="float16") + B = T.match_buffer(var_B, [512, 512], dtype="float16") + C_local = T.match_buffer(var_C, [512, 512], dtype="float32") + # body + with T.block("root"): + T.reads([]) + T.writes([]) + T.block_attr({"meta_schedule.tensor_core_enabled":"1"}) + C_local_wmma_accumulator = T.alloc_buffer([512, 512], dtype="float32", scope="wmma.accumulator") + A_shared = T.alloc_buffer([512, 512], dtype="float16", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float16", scope="shared") + A_shared_wmma_matrix_a = T.alloc_buffer([512, 512], dtype="float16", scope="wmma.matrix_a") + B_shared_wmma_matrix_b = T.alloc_buffer([512, 512], dtype="float16", scope="wmma.matrix_b") + for i0_0_0_i1_0_0_fused in T.thread_binding(0, 1, thread="blockIdx.x"): + for i0_0_1_i1_0_1_fused in T.thread_binding(0, 4, thread="blockIdx.y"): + for i0_0_2_i1_0_2_fused in T.thread_binding(0, 8, thread="threadIdx.y"): + for i0_0_4_init, i1_0_4_init in T.grid(16, 2): + with T.block("blockized_C_init"): + io = T.axis.spatial(32, i0_0_1_i1_0_1_fused // 2 * 16 + i0_0_4_init) + jo = T.axis.spatial(32, i0_0_1_i1_0_1_fused % 2 * 16 + i0_0_2_i1_0_2_fused * 2 + i1_0_4_init) + T.reads([]) + T.writes([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16]]) + for i0_1_0, i1_1_0 in T.grid(1, 1): + with T.block("blockized_C_init"): + i_inito = T.axis.spatial(1, 0) + j_inito = T.axis.spatial(1, 0) + T.reads([]) + T.writes([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16]]) + C = T.match_buffer(C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16], [16, 16], dtype="float32", scope="wmma.accumulator", offset_factor=16) + T.evaluate(T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // 256 + C.elem_offset % 256 // 16, T.float32(0), dtype="handle")) + for i2_0_0 in T.serial(0, 4): + for ax0_ax1_fused_0 in T.serial(0, 128): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("A_shared"): + v0 = T.axis.spatial(512, i0_0_1_i1_0_1_fused // 2 * 256 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) // 128) + v1 = T.axis.spatial(512, i2_0_0 * 128 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 32 + ax0_ax1_fused_2) % 128) + T.reads([A[v0, v1]]) + T.writes([A_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused_0 in T.serial(0, 32): + for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"): + for ax0_ax1_fused_2 in T.thread_binding(0, 32, thread="threadIdx.x"): + for ax0_ax1_fused_3 in T.vectorized(0, 4): + with T.block("B_shared"): + v0 = T.axis.spatial(512, i2_0_0 * 128 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 256) + v1 = T.axis.spatial(512, i0_0_1_i1_0_1_fused % 2 * 256 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 256) + T.reads([B[v0, v1]]) + T.writes([B_shared[v0, v1]]) + T.block_attr({"meta_schedule.cooperative_fetch":4}) + B_shared[v0, v1] = B[v0, v1] + for i2_0_1, i0_0_3, i1_0_3, i2_0_2 in T.grid(8, 1, 1, 1): + for ax0_0, ax1_0 in T.grid(16, 1): + with T.block("blockized_A_shared_wmma.matrix_a"): + v0o = T.axis.spatial(32, i0_0_1_i1_0_1_fused // 2 * 16 + ax0_0) + v1o = T.axis.spatial(32, i2_0_0 * 8 + i2_0_1) + T.reads([A_shared[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16]]) + T.writes([A_shared_wmma_matrix_a[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16]]) + A_1 = T.match_buffer(A_shared[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16], [16, 16], dtype="float16", strides=[s1, s0], scope="shared", offset_factor=16) + C_1 = T.match_buffer(A_shared_wmma_matrix_a[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16], [16, 16], dtype="float16", scope="wmma.matrix_a", offset_factor=16) + T.evaluate(T.tvm_load_matrix_sync(C_1.data, 16, 16, 16, C_1.elem_offset // 256 + C_1.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float16"), A_1.data, A_1.elem_offset, s1 * 16, 1, dtype="handle"), s1, "row_major", dtype="handle")) + for ax0_0, ax1_0 in T.grid(1, 2): + with T.block("blockized_B_shared_wmma.matrix_b"): + v0o = T.axis.spatial(32, i2_0_0 * 8 + i2_0_1) + v1o = T.axis.spatial(32, i0_0_1_i1_0_1_fused % 2 * 16 + i0_0_2_i1_0_2_fused * 2 + ax1_0) + T.reads([B_shared[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16]]) + T.writes([B_shared_wmma_matrix_b[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16]]) + A_2 = T.match_buffer(B_shared[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16], [16, 16], dtype="float16", strides=[s1_1, s0_1], scope="shared", offset_factor=16) + C_2 = T.match_buffer(B_shared_wmma_matrix_b[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16], [16, 16], dtype="float16", scope="wmma.matrix_b", offset_factor=16) + T.evaluate(T.tvm_load_matrix_sync(C_2.data, 16, 16, 16, C_2.elem_offset // 256 + C_2.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float16"), A_2.data, A_2.elem_offset, s1_1 * 16, 1, dtype="handle"), s1_1, "row_major", dtype="handle")) + for i0_0_4, i1_0_4 in T.grid(16, 2): + with T.block("blockized_C_update"): + io = T.axis.spatial(32, i0_0_1_i1_0_1_fused // 2 * 16 + i0_0_4) + jo = T.axis.spatial(32, i0_0_1_i1_0_1_fused % 2 * 16 + i0_0_2_i1_0_2_fused * 2 + i1_0_4) + ko = T.axis.reduce(32, i2_0_0 * 8 + i2_0_1) + T.reads([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16], A_shared_wmma_matrix_a[io * 16 : io * 16 + 16, ko * 16 : ko * 16 + 16], B_shared_wmma_matrix_b[ko * 16 : ko * 16 + 16, jo * 16 : jo * 16 + 16]]) + T.writes([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16]]) + for i0_1_0, i1_1_0, i2_1_0 in T.grid(1, 1, 1): + with T.block("blockized_C"): + io_1 = T.axis.spatial(1, 0) + jo_1 = T.axis.spatial(1, 0) + ko_1 = T.axis.reduce(1, 0) + T.reads([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16], A_shared_wmma_matrix_a[io * 16 : io * 16 + 16, ko * 16 : ko * 16 + 16], B_shared_wmma_matrix_b[ko * 16 : ko * 16 + 16, jo * 16 : jo * 16 + 16]]) + T.writes([C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16]]) + A_3 = T.match_buffer(A_shared_wmma_matrix_a[io * 16 : io * 16 + 16, ko * 16 : ko * 16 + 16], [16, 16], dtype="float16", scope="wmma.matrix_a", offset_factor=16) + B_1 = T.match_buffer(B_shared_wmma_matrix_b[ko * 16 : ko * 16 + 16, jo * 16 : jo * 16 + 16], [16, 16], dtype="float16", scope="wmma.matrix_b", offset_factor=16) + C_3 = T.match_buffer(C_local_wmma_accumulator[io * 16 : io * 16 + 16, jo * 16 : jo * 16 + 16], [16, 16], dtype="float32", scope="wmma.accumulator", offset_factor=16) + T.evaluate(T.tvm_mma_sync(C_3.data, C_3.elem_offset // 256 + C_3.elem_offset % 256 // 16, A_3.data, A_3.elem_offset // 256 + A_3.elem_offset % 256 // 16, B_1.data, B_1.elem_offset // 256 + B_1.elem_offset % 256 // 16, C_3.data, C_3.elem_offset // 256 + C_3.elem_offset % 256 // 16, dtype="handle")) + for ax0_0, ax1_0 in T.grid(16, 2): + with T.block("blockized_C_local_wmma.accumulator"): + v0o = T.axis.spatial(32, i0_0_1_i1_0_1_fused // 2 * 16 + ax0_0) + v1o = T.axis.spatial(32, i0_0_1_i1_0_1_fused % 2 * 16 + i0_0_2_i1_0_2_fused * 2 + ax1_0) + T.reads([C_local_wmma_accumulator[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16]]) + T.writes([C_local[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16]]) + A_4 = T.match_buffer(C_local_wmma_accumulator[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16], [16, 16], dtype="float32", scope="wmma.accumulator", offset_factor=16) + C_4 = T.match_buffer(C_local[v0o * 16 : v0o * 16 + 16, v1o * 16 : v1o * 16 + 16], [16, 16], dtype="float32", strides=[s1_2, s0_2], offset_factor=16) + T.evaluate(T.tvm_store_matrix_sync(A_4.data, 16, 16, 16, A_4.elem_offset // 256 + A_4.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float32"), C_4.data, C_4.elem_offset, s1_2 * 16, 2, dtype="handle"), s1_2, "row_major", dtype="handle")) + + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def test_rewrite_tensor_core(): + mod = Before0 + target = _target() + ctx = _create_context(mod, target) + sch = tir.Schedule(mod, debug_mask="all") + sch.enter_postproc() + assert ctx.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, After0) + + +if __name__ == "__main__": + test_rewrite_tensor_core() diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py new file mode 100644 index 000000000000..9b39ad1bff3e --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import RewriteUnboundBlock +from tvm.script import tir as T +from tvm.target import Target + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postprocs=[ + RewriteUnboundBlock(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + + +@tvm.script.ir_module +class Before_cooperative_fetch: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle) -> None: + A = T.match_buffer(var_A, [512, 512], dtype="float32") + B = T.match_buffer(var_B, [512, 512], dtype="float32") + for i, j in T.grid(512, 512): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + 1.0 + + +@tvm.script.ir_module +class After_cooperative_fetch: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle) -> None: + A = T.match_buffer(var_A, [512, 512], dtype="float32") + B = T.match_buffer(var_B, [512, 512], dtype="float32") + for i_j_fused_0 in T.thread_binding(0, 8192, thread="blockIdx.x"): + for i_j_fused_1 in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("C"): + vi = T.axis.spatial(512, (i_j_fused_0 * 32 + i_j_fused_1) // 512) + vj = T.axis.spatial(512, (i_j_fused_0 * 32 + i_j_fused_1) % 512) + B[vi, vj] = A[vi, vj] + 1.0 + + +@tvm.script.ir_module +class Before_norm_bmn: + @T.prim_func + def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> None: + C = T.alloc_buffer([1], dtype="float32") + for i0, i1, i2 in T.grid(1, 256, 256): + with T.block("C"): + b, i, j = T.axis.remap("SRR", [i0, i1, i2]) + with T.init(): + C[b] = T.float32(0) + C[b] = C[b] + A[b, i, j] * A[b, i, j] + for i0 in T.serial(1): + with T.block("D"): + b = T.axis.S(1, i0) + D[b] = T.sqrt(C[b], dtype="float32") + + +@tvm.script.ir_module +class After_norm_bmn: + @T.prim_func + def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> None: + C = T.alloc_buffer([1], dtype="float32") + for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"): + for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + for i1, i2 in T.grid(256, 256): + with T.block("C"): + b = T.axis.S(1, 0) + i, j = T.axis.remap("RR", [i1, i2]) + T.where(i0_fused_1 < 1) + with T.init(): + C[b] = T.float32(0) + C[b] = C[b] + A[b, i, j] * A[b, i, j] + for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"): + for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): + with T.block("D"): + b = T.axis.S(1, 0) + T.where(i0_fused_1 < 1) + D[b] = T.sqrt(C[b], dtype="float32") + + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def test_rewrite_cooperative_fetch(): + mod = Before_cooperative_fetch + target = _target() + ctx = _create_context(mod, target) + sch = tir.Schedule(mod, debug_mask="all") + sch.enter_postproc() + assert ctx.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, After_cooperative_fetch) + + +def test_rewrite_norm_bmn(): + mod = Before_norm_bmn + target = _target() + ctx = _create_context(mod, target) + sch = tir.Schedule(mod, debug_mask="all") + sch.enter_postproc() + assert ctx.postprocs[0].apply(sch) + tvm.ir.assert_structural_equal(sch.mod, After_norm_bmn) + + +if __name__ == "__main__": + test_rewrite_cooperative_fetch() + test_rewrite_norm_bmn() diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py new file mode 100644 index 000000000000..f3a318f91358 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py @@ -0,0 +1,427 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import VerifyGPUCode +from tvm.script import tir as T +from tvm.target import Target + + +def _target() -> Target: + return Target("nvidia/geforce-rtx-3080") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postprocs=[ + VerifyGPUCode(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant +# fmt: off + +@tvm.script.ir_module +class Conv2dCuda0: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([64], "float32", "local") + Apad_shared = T.allocate([512], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 8) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + +@tvm.script.ir_module +class Conv2dCuda1: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([6400000], "float32", "local") + Apad_shared = T.allocate([512], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 8) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + +@tvm.script.ir_module +class Conv2dCuda2: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([64], "float32", "local") + Apad_shared = T.allocate([512000], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 8) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + +@tvm.script.ir_module +class Conv2dCuda3: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([64], "float32", "local") + Apad_shared = T.allocate([512], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 800000) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + +@T.prim_func +def GmmCuda0(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None: + Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local") + X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"): + for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"): + for i1_3_init, i2_4_init in T.grid(4, 2): + with T.block("Z_init"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init) + T.reads() + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = T.float32(0) + for i3_0 in T.serial(4): + for ax0_ax1_ax2_fused_0 in T.serial(4): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in T.vectorized(2): + with T.block("X_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32) + v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32) + T.reads(X[v0, v1, v2]) + T.writes(X_shared[v0, v1, v2]) + X_shared[v0, v1, v2] = X[v0, v1, v2] + for ax0_ax1_ax2_fused_0 in T.serial(8): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.block("Y_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32) + T.reads(Y[v0, v1, v2]) + T.writes(Y_shared[v0, v1, v2]) + Y_shared[v0, v1, v2] = Y[v0, v1, v2] + for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2): + with T.block("Z_update"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4) + k = T.axis.reduce(128, i3_0 * 32 + i3_2) + T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j]) + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j] + for ax0, ax1, ax2 in T.grid(1, 4, 2): + with T.block("Z_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2) + T.reads(Z_local[v0, v1, v2]) + T.writes(Z[v0, v1, v2]) + Z[v0, v1, v2] = Z_local[v0, v1, v2] + +@T.prim_func +def GmmCuda1(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None: + Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local") + X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"): + for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"): + for i1_3_init, i2_4_init in T.grid(4, 2): + with T.block("Z_init"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init) + T.reads() + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = T.float32(0) + for i3_0 in T.serial(4): + for ax0_ax1_ax2_fused_0 in T.serial(4): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in T.vectorized(2): + with T.block("X_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32) + v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32) + T.reads(X[v0, v1, v2]) + T.writes(X_shared[v0, v1, v2]) + X_shared[v0, v1, v2] = X[v0, v1, v2] + for ax0_ax1_ax2_fused_0 in T.serial(8): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.block("Y_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32) + T.reads(Y[v0, v1, v2]) + T.writes(Y_shared[v0, v1, v2]) + Y_shared[v0, v1, v2] = Y[v0, v1, v2] + for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2): + with T.block("Z_update"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4) + k = T.axis.reduce(128, i3_0 * 32 + i3_2) + T.block_attr({ + "meta_schedule.thread_extent_low_inclusive": 0, + "meta_schedule.thread_extent_high_inclusive": 32, + }) + T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j]) + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j] + for ax0, ax1, ax2 in T.grid(1, 4, 2): + with T.block("Z_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2) + T.reads(Z_local[v0, v1, v2]) + T.writes(Z[v0, v1, v2]) + Z[v0, v1, v2] = Z_local[v0, v1, v2] + + +@T.prim_func +def GmmCuda2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None: + Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local") + X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared") + for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"): + for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"): + for i1_3_init, i2_4_init in T.grid(4, 2): + with T.block("Z_init"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init) + T.reads() + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = T.float32(0) + for i3_0 in T.serial(4): + for ax0_ax1_ax2_fused_0 in T.serial(4): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in T.vectorized(2): + with T.block("X_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32) + v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32) + T.reads(X[v0, v1, v2]) + T.writes(X_shared[v0, v1, v2]) + X_shared[v0, v1, v2] = X[v0, v1, v2] + for ax0_ax1_ax2_fused_0 in T.serial(8): + for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"): + with T.block("Y_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32) + T.reads(Y[v0, v1, v2]) + T.writes(Y_shared[v0, v1, v2]) + Y_shared[v0, v1, v2] = Y[v0, v1, v2] + for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2): + with T.block("Z_update"): + b = T.axis.spatial(1, 0) + i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3) + j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4) + k = T.axis.reduce(128, i3_0 * 32 + i3_2) + T.block_attr({ + "meta_schedule.thread_extent_low_inclusive": 1024, + "meta_schedule.thread_extent_high_inclusive": 1024, + }) + T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j]) + T.writes(Z_local[b, i, j]) + Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j] + for ax0, ax1, ax2 in T.grid(1, 4, 2): + with T.block("Z_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1) + v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2) + T.reads(Z_local[v0, v1, v2]) + T.writes(Z[v0, v1, v2]) + Z[v0, v1, v2] = Z_local[v0, v1, v2] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant + + +def test_postproc_verify_gpu_0(): + mod = Conv2dCuda0 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_1(): + mod = Conv2dCuda1 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_2(): + mod = Conv2dCuda2 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_3(): + mod = Conv2dCuda3 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_4(): + mod = GmmCuda0 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_5(): + mod = GmmCuda1 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_6(): + mod = GmmCuda2 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +if __name__ == "__main__": + test_postproc_verify_gpu_0() + test_postproc_verify_gpu_1() + test_postproc_verify_gpu_2() + test_postproc_verify_gpu_3() + test_postproc_verify_gpu_4() + test_postproc_verify_gpu_5() + test_postproc_verify_gpu_6() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule.py b/tests/python/unittest/test_meta_schedule_schedule_rule.py new file mode 100644 index 000000000000..1d34d94bfe05 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import math +import re +from typing import List + +import tvm +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.schedule_rule import PyScheduleRule +from tvm.script import tir as T +from tvm.tir.schedule import BlockRV, Schedule + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def _check_correct(schedule: Schedule): + trace = schedule.trace + for inst in trace.decisions: + assert math.prod(trace.decisions[inst]) == 1024 + + +def test_meta_schedule_schedule_rule(): + class FancyScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: TuneContext) -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + i, j, k = sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + return [sch] + + sch_rule = FancyScheduleRule() + mod = Matmul + sch = Schedule(mod) + res = sch_rule.apply(sch, block=sch.get_block("matmul")) + assert len(res) == 1 + try: + tvm.ir.assert_structural_equal(mod, res[0].mod) + raise Exception("The schedule rule did not change the schedule.") + except ValueError: + _check_correct(res[0]) + + +def test_meta_schedule_schedule_rule_as_string(): + class YetStillSomeFancyScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: TuneContext) -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + pass + + sch_rule = YetStillSomeFancyScheduleRule() + pattern = re.compile(r"YetStillSomeFancyScheduleRule\(0x[a-f|0-9]*\)") + assert pattern.match(str(sch_rule)) + + +if __name__ == "__main__": + test_meta_schedule_schedule_rule() + test_meta_schedule_schedule_rule_as_string() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py new file mode 100644 index 000000000000..5a8031220354 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.testing.schedule_rule import add_rfactor +from tvm.meta_schedule.testing.space_generation import check_trace +from tvm.meta_schedule.tune_context import TuneContext +from tvm.target import Target +from tvm.te.operation import create_prim_func + + +def _create_context(mod, target, rule) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_cpu_matmul(): + expected = [ + [], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l6, l7 = sch.split(loop=l3, factors=[v4, v5])", + "b8 = sch.rfactor(loop=l7, factor_axis=2)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l6, l7 = sch.split(loop=l3, factors=[v4, v5])", + "b8 = sch.rfactor(loop=l6, factor_axis=2)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + ], + ] + target = Target("llvm --num-cores=32") + ctx = _create_context( + create_prim_func( + te_workload.matmul( + n=4, + m=4, + k=512, + ) + ), + target=target, + rule=add_rfactor(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_cpu_matmul() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py new file mode 100644 index 000000000000..1d8e8515f24f --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py @@ -0,0 +1,384 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing.schedule_rule import ( + auto_inline, + auto_inline_after_tiling, +) +from tvm.meta_schedule.schedule_rule import ( + AutoInline, +) +from tvm.meta_schedule.tune_context import TuneContext +from tvm.script import tir as T +from tvm.target import Target + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class Conv2DBiasBnReLU: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32") + compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + bias_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + bn_mul = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + bn_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([X[i0_1, i1_1, i2_1 - 1, i3_1 - 1]]) + T.writes([pad_temp[i0_1, i1_1, i2_1, i3_1]]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") + for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 512, 56, 56, 512, 3, 3): + with T.block("compute"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads([compute_1[nn, ff, yy, xx], pad_temp[nn, rc, yy + ry, xx + rx], W[ff, rc, ry, rx]]) + T.writes([compute_1[nn, ff, yy, xx]]) + with T.init(): + compute_1[nn, ff, yy, xx] = T.float32(0) + compute_1[nn, ff, yy, xx] = compute_1[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * W[ff, rc, ry, rx] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("bias_add"): + i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([compute_1[i, j, k, l], B[j, 0, 0]]) + T.writes([bias_add[i, j, k, l]]) + bias_add[i, j, k, l] = compute_1[i, j, k, l] + B[j, 0, 0] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("bn_mul"): + i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([bias_add[i, j, k, l], bn_scale[j, 0, 0]]) + T.writes([bn_mul[i, j, k, l]]) + bn_mul[i, j, k, l] = bias_add[i, j, k, l] * bn_scale[j, 0, 0] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("bn_add"): + i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([bn_mul[i, j, k, l], bn_offset[j, 0, 0]]) + T.writes([bn_add[i, j, k, l]]) + bn_add[i, j, k, l] = bn_mul[i, j, k, l] + bn_offset[j, 0, 0] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([bn_add[i0_2, i1_2, i2_2, i3_2]]) + T.writes([compute[i0_2, i1_2, i2_2, i3_2]]) + compute[i0_2, i1_2, i2_2, i3_2] = T.max(bn_add[i0_2, i1_2, i2_2, i3_2], T.float32(0)) + + +@tvm.script.ir_module +class Conv2DBiasBnReLUInlined: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32") + compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([X[i0_1, i1_1, i2_1 - 1, i3_1 - 1]]) + T.writes([pad_temp[i0_1, i1_1, i2_1, i3_1]]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") + for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 512, 56, 56, 512, 3, 3): + with T.block("compute"): + nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]) + T.reads([compute_1[nn, ff, yy, xx], pad_temp[nn, rc, yy + ry, xx + rx], W[ff, rc, ry, rx]]) + T.writes([compute_1[nn, ff, yy, xx]]) + with T.init(): + compute_1[nn, ff, yy, xx] = T.float32(0) + compute_1[nn, ff, yy, xx] = compute_1[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * W[ff, rc, ry, rx] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([compute_1[i0_2, i1_2, i2_2, i3_2], B[i1_2, 0, 0], bn_scale[i1_2, 0, 0], bn_offset[i1_2, 0, 0]]) + T.writes([compute[i0_2, i1_2, i2_2, i3_2]]) + compute[i0_2, i1_2, i2_2, i3_2] = T.max((compute_1[i0_2, i1_2, i2_2, i3_2] + B[i1_2, 0, 0]) * bn_scale[i1_2, 0, 0] + bn_offset[i1_2, 0, 0], T.float32(0)) + + +@tvm.script.ir_module +class NeedsInlinePaddingAndEpilogue: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32") + compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32") + compute_local = T.alloc_buffer([1, 512, 56, 56], dtype="float32", scope="local") + pad_temp_shared = T.alloc_buffer([1, 512, 58, 58], dtype="float32", scope="shared") + W_shared = T.alloc_buffer([512, 512, 3, 3], dtype="float32", scope="shared") + for i0, i1, i2, i3 in T.grid(1, 512, 58, 58): + with T.block("pad_temp"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([X[i0_1, i1_1, i2_1 - 1, i3_1 - 1]]) + T.writes([pad_temp[i0_1, i1_1, i2_1, i3_1]]) + pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57, X[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") + for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(0, 224, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(0, 2, thread="vthread.x"): + for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i4_0, i5_0, i6_0 in T.grid(1, 3, 1): + for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 40960, annotations={"meta_schedule.cooperative_fetch":1}): + for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 3): + with T.block("pad_temp_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 // 8 % 512) + v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i5_0 + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 % 8) + v3 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) % 30) + T.reads([pad_temp[v0, v1, v2, v3]]) + T.writes([pad_temp_shared[v0, v1, v2, v3]]) + T.block_attr({"meta_schedule.cache_type":0}) + pad_temp_shared[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 12288, annotations={"meta_schedule.cooperative_fetch":1}): + for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 4): + with T.block("W_shared"): + v0 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 1536) + v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 3 % 512) + v2 = T.axis.spatial(3, i5_0) + v3 = T.axis.spatial(3, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) % 3) + T.reads([W[v0, v1, v2, v3]]) + T.writes([W_shared[v0, v1, v2, v3]]) + T.block_attr({"meta_schedule.cache_type":0}) + W_shared[v0, v1, v2, v3] = W[v0, v1, v2, v3] + for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28): + with T.block("compute"): + nn = T.axis.spatial(1, 0) + ff = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4) + xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + i3_4) + rc = T.axis.reduce(512, i4_1 * 16 + i4_2) + ry, rx = T.axis.remap("RR", [i5_0, i6_2]) + T.reads([compute_local[nn, ff, yy, xx], pad_temp_shared[nn, rc, yy + ry, xx + rx], W_shared[ff, rc, ry, rx]]) + T.writes([compute_local[nn, ff, yy, xx]]) + with T.init(): + compute_local[nn, ff, yy, xx] = T.float32(0) + compute_local[nn, ff, yy, xx] = compute_local[nn, ff, yy, xx] + pad_temp_shared[nn, rc, yy + ry, xx + rx] * W_shared[ff, rc, ry, rx] + for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28): + with T.block("compute_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2) + v3 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3) + T.block_attr({"meta_schedule.cache_type":1}) + T.reads([compute_local[v0, v1, v2, v3]]) + T.writes([compute_1[v0, v1, v2, v3]]) + compute_1[v0, v1, v2, v3] = compute_local[v0, v1, v2, v3] + for i0, i1, i2, i3 in T.grid(1, 512, 56, 56): + with T.block("compute_1"): + i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads([compute_1[i0_2, i1_2, i2_2, i3_2], B[i1_2, 0, 0], bn_scale[i1_2, 0, 0], bn_offset[i1_2, 0, 0]]) + T.writes([compute[i0_2, i1_2, i2_2, i3_2]]) + compute[i0_2, i1_2, i2_2, i3_2] = T.max((compute_1[i0_2, i1_2, i2_2, i3_2] + B[i1_2, 0, 0]) * bn_scale[i1_2, 0, 0] + bn_offset[i1_2, 0, 0], T.float32(0)) + + +@tvm.script.ir_module +class PaddingAndEpilogueInlined: + @T.prim_func + def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None: + X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32") + W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32") + B = T.match_buffer(var_B, [512, 1, 1], dtype="float32") + bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32") + bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32") + compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32") + compute_local = T.alloc_buffer([1, 512, 56, 56], dtype="float32", scope="local") + pad_temp_shared = T.alloc_buffer([1, 512, 58, 58], dtype="float32", scope="shared") + W_shared = T.alloc_buffer([512, 512, 3, 3], dtype="float32", scope="shared") + for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(0, 224, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(0, 2, thread="vthread.x"): + for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"): + for i4_0, i5_0, i6_0 in T.grid(1, 3, 1): + for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 40960, annotations={"meta_schedule.cooperative_fetch":1}): + for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 3): + with T.block("pad_temp_shared"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 // 8 % 512) + v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i5_0 + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 % 8) + v3 = T.axis.spatial(58, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) % 30) + T.reads([X[v0, v1, v2 - 1, v3 - 1]]) + T.writes([pad_temp_shared[v0, v1, v2, v3]]) + T.block_attr({"meta_schedule.cache_type":0}) + pad_temp_shared[v0, v1, v2, v3] = T.if_then_else(v2 >= 1 and v2 < 57 and v3 >= 1 and v3 < 57, X[v0, v1, v2 - 1, v3 - 1], T.float32(0), dtype="float32") + for ax0_ax1_ax2_ax3_fused_0 in T.serial(0, 12288, annotations={"meta_schedule.cooperative_fetch":1}): + for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 4): + with T.block("W_shared"): + v0 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 1536) + v1 = T.axis.spatial(512, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 3 % 512) + v2 = T.axis.spatial(3, i5_0) + v3 = T.axis.spatial(3, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) % 3) + T.reads([W[v0, v1, v2, v3]]) + T.writes([W_shared[v0, v1, v2, v3]]) + T.block_attr({"meta_schedule.cache_type":0}) + W_shared[v0, v1, v2, v3] = W[v0, v1, v2, v3] + for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28): + with T.block("compute"): + nn = T.axis.spatial(1, 0) + ff = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4) + yy = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4) + xx = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + i3_4) + rc = T.axis.reduce(512, i4_1 * 16 + i4_2) + ry, rx = T.axis.remap("RR", [i5_0, i6_2]) + T.reads([compute_local[nn, ff, yy, xx], pad_temp_shared[nn, rc, yy + ry, xx + rx], W_shared[ff, rc, ry, rx]]) + T.writes([compute_local[nn, ff, yy, xx]]) + with T.init(): + compute_local[nn, ff, yy, xx] = T.float32(0) + compute_local[nn, ff, yy, xx] = compute_local[nn, ff, yy, xx] + pad_temp_shared[nn, rc, yy + ry, xx + rx] * W_shared[ff, rc, ry, rx] + for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28): + with T.block("compute_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1) + v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i0_1_i1_1_i2_1_i3_1_fused * 4 + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2) + v3 = T.axis.spatial(56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3) + T.reads([compute_local[v0, v1, v2, v3], B[v1, 0, 0], bn_scale[v1, 0, 0], bn_offset[v1, 0, 0]]) + T.writes([compute[v0, v1, v2, v3]]) + T.block_attr({"meta_schedule.cache_type":1}) + compute[v0, v1, v2, v3] = T.max((compute_local[v0, v1, v2, v3] + B[v1, 0, 0]) * bn_scale[v1, 0, 0] + bn_offset[v1, 0, 0], T.float32(0)) + + +@tvm.script.ir_module +class SoftmaxBeforeInline: + @T.prim_func + def main(A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None: + T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") + T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32") + T_softmax_expsum = T.alloc_buffer([256], dtype="float32") + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_maxelem"): + i0_1, k = T.axis.remap("SR", [i0, i1]) + with T.init(): + T_softmax_maxelem[i0_1] = T.min_value("float32") + T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_exp"): + i0_2, i1_1 = T.axis.remap("SS", [i0, i1]) + T_softmax_exp[i0_2, i1_1] = T.exp(A[i0_2, i1_1] - T_softmax_maxelem[i0_2], dtype="float32") + for i0_3, i1 in T.grid(256, 256): + with T.block("T_softmax_expsum"): + i0_4, k = T.axis.remap("SR", [i0_3, i1]) + with T.init(): + T_softmax_expsum[i0_4] = T.float32(0) + T_softmax_expsum[i0_4] = T_softmax_expsum[i0_4] + T_softmax_exp[i0_4, k] + for i0_5, i1 in T.grid(256, 256): + with T.block("T_softmax_norm"): + i0_6, i1_2 = T.axis.remap("SS", [i0_5, i1]) + T_softmax_norm[i0_6, i1_2] = T_softmax_exp[i0_6, i1_2] / T_softmax_expsum[i0_6] + + +@tvm.script.ir_module +class SoftmaxAfterInline: + @T.prim_func + def main(A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None: + T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") + T_softmax_expsum = T.alloc_buffer([256], dtype="float32") + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_maxelem"): + i0_1, k = T.axis.remap("SR", [i0, i1]) + with T.init(): + T_softmax_maxelem[i0_1] = T.min_value("float32") + T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_expsum"): + i0_2, k = T.axis.remap("SR", [i0, i1]) + with T.init(): + T_softmax_expsum[i0_2] = T.float32(0) + T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp(A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32") + for i0_3, i1 in T.grid(256, 256): + with T.block("T_softmax_norm"): + i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1]) + T_softmax_norm[i0_4, i1_1] = T.exp(A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32") / T_softmax_expsum[i0_4] + + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def _create_context(mod, target, rule): + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_inline_consumer_chain(): + mod = Conv2DBiasBnReLU + target = Target("llvm") + ctx = _create_context( + mod=mod, + target=target, + rule=auto_inline(target=target), + ) + (space,) = ctx.space_generator.generate_design_space(mod=mod) + tvm.ir.assert_structural_equal(lhs=space.mod, rhs=Conv2DBiasBnReLUInlined) + + +def test_inline_into_cache(): + mod = NeedsInlinePaddingAndEpilogue + target = Target("cuda", host="llvm") + ctx = _create_context( + mod=mod, + target=target, + rule=AutoInline( + into_producer=True, + into_consumer=True, + into_cache_only=True, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + disallow_op=None, + ), + ) + (space,) = ctx.space_generator.generate_design_space(mod=mod) + tvm.ir.assert_structural_equal(lhs=space.mod, rhs=PaddingAndEpilogueInlined) + + +def test_inline_into_multiple_consumers(): + mod = SoftmaxBeforeInline + target = Target("cuda", host="llvm") + ctx = _create_context( + mod=mod, + target=target, + rule=auto_inline(target=target), + ) + (space,) = ctx.space_generator.generate_design_space(mod=mod) + tvm.ir.assert_structural_equal(lhs=space.mod, rhs=SoftmaxAfterInline) + + +if __name__ == "__main__": + test_inline_consumer_chain() + test_inline_into_cache() + test_inline_into_multiple_consumers() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py new file mode 100644 index 000000000000..47f405842c98 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py @@ -0,0 +1,241 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.testing.schedule_rule import cross_thread_reduction +from tvm.meta_schedule.testing.space_generation import check_trace +from tvm.meta_schedule.tune_context import TuneContext +from tvm.target import Target +from tvm.te.operation import create_prim_func + +import tvm +from tvm.script import tir as T + + +@tvm.script.ir_module +class Softmax_mn_after_inline: + @T.prim_func + def main( + A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"] + ) -> None: + T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") + T_softmax_expsum = T.alloc_buffer([256], dtype="float32") + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_maxelem"): + i0_1, k = T.axis.remap("SR", [i0, i1]) + with T.init(): + T_softmax_maxelem[i0_1] = T.min_value("float32") + T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_expsum"): + i0_2, k = T.axis.remap("SR", [i0, i1]) + with T.init(): + T_softmax_expsum[i0_2] = T.float32(0) + T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp( + A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32" + ) + for i0_3, i1 in T.grid(256, 256): + with T.block("T_softmax_norm"): + i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[i0_4, i1_1] = ( + T.exp(A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32") + / T_softmax_expsum[i0_4] + ) + + +def _create_context(mod, target, rule) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_gpu_softmax_mn(): + expected = [ + [], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + "b1, = sch.get_consumers(block=b0)", + "l2, l3 = sch.get_loops(block=b1)", + "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l5, l6 = sch.split(loop=l3, factors=[None, v4])", + 'sch.bind(loop=l6, thread_axis="threadIdx.x")', + "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)", + 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', + "l7, l8, l9 = sch.get_loops(block=b0)", + "l10, l11 = sch.split(loop=l9, factors=[None, v4])", + 'sch.bind(loop=l11, thread_axis="threadIdx.x")', + ], + [ + 'b0 = sch.get_block(name="T_softmax_expsum", func_name="main")', + "b1, = sch.get_consumers(block=b0)", + "l2, l3 = sch.get_loops(block=b1)", + "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l5, l6 = sch.split(loop=l3, factors=[None, v4])", + 'sch.bind(loop=l6, thread_axis="threadIdx.x")', + "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)", + 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', + "l7, l8, l9 = sch.get_loops(block=b0)", + "l10, l11 = sch.split(loop=l9, factors=[None, v4])", + 'sch.bind(loop=l11, thread_axis="threadIdx.x")', + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_expsum", func_name="main")', + "b2, = sch.get_consumers(block=b1)", + "l3, l4 = sch.get_loops(block=b2)", + "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l6, l7 = sch.split(loop=l4, factors=[None, v5])", + 'sch.bind(loop=l7, thread_axis="threadIdx.x")', + "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)", + 'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")', + "l8, l9, l10 = sch.get_loops(block=b1)", + "l11, l12 = sch.split(loop=l10, factors=[None, v5])", + 'sch.bind(loop=l12, thread_axis="threadIdx.x")', + "b13, = sch.get_consumers(block=b0)", + "l14, l15 = sch.get_loops(block=b13)", + "v16 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l17, l18 = sch.split(loop=l15, factors=[None, v16])", + 'sch.bind(loop=l18, thread_axis="threadIdx.x")', + "sch.compute_at(block=b0, loop=l14, preserve_unit_loops=True)", + 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', + "l19, l20, l21 = sch.get_loops(block=b0)", + "l22, l23 = sch.split(loop=l21, factors=[None, v16])", + 'sch.bind(loop=l23, thread_axis="threadIdx.x")', + ], + ] + target = Target("nvidia/geforce-rtx-3090", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.softmax_mn( + n=256, + m=256, + ) + ), + target=target, + rule=cross_thread_reduction(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 4 + check_trace(spaces, expected) + + +def test_gpu_softmax_mn_after_inline(): + expected = [ + [], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + "v1 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l2, l3 = sch.get_loops(block=b0)", + "l4, l5 = sch.split(loop=l3, factors=[None, v1])", + 'sch.bind(loop=l5, thread_axis="threadIdx.x")', + ], + [ + 'b0 = sch.get_block(name="T_softmax_expsum", func_name="main")', + "b1, = sch.get_consumers(block=b0)", + "l2, l3 = sch.get_loops(block=b1)", + "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l5, l6 = sch.split(loop=l3, factors=[None, v4])", + 'sch.bind(loop=l6, thread_axis="threadIdx.x")', + "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True)", + 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', + "l7, l8, l9 = sch.get_loops(block=b0)", + "l10, l11 = sch.split(loop=l9, factors=[None, v4])", + 'sch.bind(loop=l11, thread_axis="threadIdx.x")', + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_expsum", func_name="main")', + "b2, = sch.get_consumers(block=b1)", + "l3, l4 = sch.get_loops(block=b2)", + "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l6, l7 = sch.split(loop=l4, factors=[None, v5])", + 'sch.bind(loop=l7, thread_axis="threadIdx.x")', + "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True)", + 'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")', + "l8, l9, l10 = sch.get_loops(block=b1)", + "l11, l12 = sch.split(loop=l10, factors=[None, v5])", + 'sch.bind(loop=l12, thread_axis="threadIdx.x")', + "b13, b14 = sch.get_consumers(block=b0)", + "l15, l16, l17, l18 = sch.get_loops(block=b13)", + "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True)", + 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', + "l19, l20, l21 = sch.get_loops(block=b0)", + "l22, l23 = sch.split(loop=l21, factors=[None, v5])", + 'sch.bind(loop=l23, thread_axis="threadIdx.x")', + ], + ] + target = Target("nvidia/geforce-rtx-3090", host="llvm") + ctx = _create_context( + mod=Softmax_mn_after_inline, + target=target, + rule=cross_thread_reduction(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 4 + check_trace(spaces, expected) + + +def test_gpu_batch_norm_bmn(): + expected = [ + [], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + "b1, = sch.get_consumers(block=b0)", + "l2, = sch.get_loops(block=b1)", + "v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l4, l5 = sch.split(loop=l2, factors=[None, v3])", + 'sch.bind(loop=l5, thread_axis="threadIdx.x")', + "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True)", + 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', + "l6, l7, l8, l9 = sch.get_loops(block=b0)", + "l10 = sch.fuse(l8, l9)", + "l11, l12 = sch.split(loop=l10, factors=[None, v3])", + 'sch.bind(loop=l12, thread_axis="threadIdx.x")', + ], + ] + target = Target("nvidia/geforce-rtx-3090", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.norm_bmn( + B=1, + M=512, + N=512, + ) + ), + target=target, + rule=cross_thread_reduction(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 2 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_gpu_softmax_mn() + test_gpu_softmax_mn_after_inline() + test_gpu_batch_norm_bmn() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py new file mode 100644 index 000000000000..c2ad9258f275 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py @@ -0,0 +1,428 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing.schedule_rule import ( + multi_level_tiling, + multi_level_tiling_tensor_core, +) +from tvm.meta_schedule.testing.space_generation import check_trace +from tvm.meta_schedule.tune_context import TuneContext +from tvm.te import create_prim_func +from tvm.meta_schedule.testing import te_workload +from tvm.target import Target +from tvm.meta_schedule.testing import tir_tensor_intrin + + +def _create_context(mod, target, rule) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_cpu_matmul(): + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l17, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + ], + ] + target = Target("llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_cpu_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "b1, = sch.get_consumers(block=b0)", + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "b1, = sch.get_consumers(block=b0)", + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + "sch.reverse_compute_at(block=b1, loop=l17, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + ], + ] + # pylint: enable=line-too-long + target = Target("llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_cuda_matmul(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", + "l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])", + "v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", + "l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])", + "v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64)", + "l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])", + "sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)", + "l31 = sch.fuse(l10, l20)", + 'sch.bind(loop=l31, thread_axis="blockIdx.x")', + "l32 = sch.fuse(l11, l21)", + 'sch.bind(loop=l32, thread_axis="vthread.x")', + "l33 = sch.fuse(l12, l22)", + 'sch.bind(loop=l33, thread_axis="threadIdx.x")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)', + 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)", + "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", + "l41 = sch.fuse(l39, l40)", + "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)', + 'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b43, loop=l28, preserve_unit_loops=True)", + "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)", + "l50 = sch.fuse(l48, l49)", + "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)', + "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)", + ] + ] + # pylint: enable=line-too-long + target = Target("cuda --max_threads_per_block=1024 --thread_warp_size=32", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_cuda_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + 'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", + "l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])", + "v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", + "l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])", + "v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64)", + "l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])", + "sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)", + "l31 = sch.fuse(l10, l20)", + 'sch.bind(loop=l31, thread_axis="blockIdx.x")', + "l32 = sch.fuse(l11, l21)", + 'sch.bind(loop=l32, thread_axis="vthread.x")', + "l33 = sch.fuse(l12, l22)", + 'sch.bind(loop=l33, thread_axis="threadIdx.x")', + 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)", + "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", + "l41 = sch.fuse(l39, l40)", + "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)', + 'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b43, loop=l28, preserve_unit_loops=True)", + "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)", + "l50 = sch.fuse(l48, l49)", + "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)', + "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)", + ] + ] + # pylint: enable=line-too-long + target = Target("cuda", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_cuda_tensor_core_matmul(): + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "l4, l5 = sch.split(loop=l1, factors=[32, 16])", + "l6, l7 = sch.split(loop=l2, factors=[32, 16])", + "l8, l9 = sch.split(loop=l3, factors=[32, 16])", + "l10, l11, l12, l13, l14, l15 = sch.get_loops(block=b0)", + "sch.reorder(l12, l14, l5, l7, l9)", + "b16 = sch.blockize(loop=l5)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync")', + 'sch.annotate(block_or_loop=b16, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_fill")', + 'b17 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b17, ann_key="meta_schedule.tensor_core_enabled", ann_val="1")', + 'b18 = sch.cache_write(block=b16, write_buffer_index=0, storage_scope="local")', + 'b19 = sch.cache_write(block=b16, write_buffer_index=0, storage_scope="wmma.accumulator")', + 'sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store")', + "l20, l21, l22 = sch.get_loops(block=b16)", + "v23, v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l20, n=5, max_innermost_factor=64)", + "l28, l29, l30, l31, l32 = sch.split(loop=l20, factors=[v23, v24, v25, v26, v27])", + "v33, v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=64)", + "l38, l39, l40, l41, l42 = sch.split(loop=l21, factors=[v33, v34, v35, v36, v37])", + "v43, v44, v45 = sch.sample_perfect_tile(loop=l22, n=3, max_innermost_factor=64)", + "l46, l47, l48 = sch.split(loop=l22, factors=[v43, v44, v45])", + "sch.reorder(l28, l38, l29, l39, l30, l40, l46, l47, l31, l41, l48, l32, l42)", + "l49 = sch.fuse(l28, l38)", + 'sch.bind(loop=l49, thread_axis="blockIdx.x")', + "l50 = sch.fuse(l29, l39)", + 'sch.bind(loop=l50, thread_axis="blockIdx.y")', + "l51 = sch.fuse(l30, l40)", + 'sch.bind(loop=l51, thread_axis="threadIdx.y")', + 'b52 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b52, loop=l46, preserve_unit_loops=True)", + "l53, l54, l55, l56, l57, l58 = sch.get_loops(block=b52)", + "l59 = sch.fuse(l57, l58)", + "v60 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b52, ann_key="meta_schedule.cooperative_fetch", ann_val=v60)', + 'b61 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b61, loop=l46, preserve_unit_loops=True)", + "l62, l63, l64, l65, l66, l67 = sch.get_loops(block=b61)", + "l68 = sch.fuse(l66, l67)", + "v69 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v69)', + 'b70 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="wmma.matrix_a")', + 'b71 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="wmma.matrix_b")', + "sch.compute_at(block=b70, loop=l48, preserve_unit_loops=True)", + "sch.compute_at(block=b71, loop=l48, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b70, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_a")', + 'sch.annotate(block_or_loop=b71, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_b")', + "sch.reverse_compute_at(block=b19, loop=l51, preserve_unit_loops=True)", + "sch.reverse_compute_at(block=b18, loop=l51, preserve_unit_loops=True)", + ] + ] + target = Target("cuda", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul_fp16( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling_tensor_core(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_cuda_tensor_core_matmul_relu(): + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "l4, l5 = sch.split(loop=l1, factors=[32, 16])", + "l6, l7 = sch.split(loop=l2, factors=[32, 16])", + "l8, l9 = sch.split(loop=l3, factors=[32, 16])", + "l10, l11, l12, l13, l14, l15 = sch.get_loops(block=b0)", + "sch.reorder(l12, l14, l5, l7, l9)", + "b16 = sch.blockize(loop=l5)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync")', + 'sch.annotate(block_or_loop=b16, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_fill")', + 'b17 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b17, ann_key="meta_schedule.tensor_core_enabled", ann_val="1")', + 'b18 = sch.cache_write(block=b16, write_buffer_index=0, storage_scope="local")', + 'b19 = sch.cache_write(block=b16, write_buffer_index=0, storage_scope="wmma.accumulator")', + 'sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store")', + "l20, l21, l22 = sch.get_loops(block=b16)", + "v23, v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l20, n=5, max_innermost_factor=64)", + "l28, l29, l30, l31, l32 = sch.split(loop=l20, factors=[v23, v24, v25, v26, v27])", + "v33, v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=64)", + "l38, l39, l40, l41, l42 = sch.split(loop=l21, factors=[v33, v34, v35, v36, v37])", + "v43, v44, v45 = sch.sample_perfect_tile(loop=l22, n=3, max_innermost_factor=64)", + "l46, l47, l48 = sch.split(loop=l22, factors=[v43, v44, v45])", + "sch.reorder(l28, l38, l29, l39, l30, l40, l46, l47, l31, l41, l48, l32, l42)", + "l49 = sch.fuse(l28, l38)", + 'sch.bind(loop=l49, thread_axis="blockIdx.x")', + "l50 = sch.fuse(l29, l39)", + 'sch.bind(loop=l50, thread_axis="blockIdx.y")', + "l51 = sch.fuse(l30, l40)", + 'sch.bind(loop=l51, thread_axis="threadIdx.y")', + 'b52 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b52, loop=l46, preserve_unit_loops=True)", + "l53, l54, l55, l56, l57, l58 = sch.get_loops(block=b52)", + "l59 = sch.fuse(l57, l58)", + "v60 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b52, ann_key="meta_schedule.cooperative_fetch", ann_val=v60)', + 'b61 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b61, loop=l46, preserve_unit_loops=True)", + "l62, l63, l64, l65, l66, l67 = sch.get_loops(block=b61)", + "l68 = sch.fuse(l66, l67)", + "v69 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v69)', + 'b70 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="wmma.matrix_a")', + 'b71 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="wmma.matrix_b")', + "sch.compute_at(block=b70, loop=l48, preserve_unit_loops=True)", + "sch.compute_at(block=b71, loop=l48, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b70, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_a")', + 'sch.annotate(block_or_loop=b71, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_b")', + "sch.reverse_compute_at(block=b19, loop=l51, preserve_unit_loops=True)", + "sch.reverse_compute_at(block=b18, loop=l51, preserve_unit_loops=True)", + ] + ] + target = Target("cuda", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul_relu_fp16( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling_tensor_core(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_cpu_matmul() + test_cpu_matmul_relu() + test_cuda_matmul() + test_cuda_matmul_relu() + test_cuda_tensor_core_matmul() + test_cuda_tensor_core_matmul_relu() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py new file mode 100644 index 000000000000..e57799f604b8 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing.schedule_rule import parallel_vectorize_unroll +from tvm.meta_schedule.testing.space_generation import check_trace +from tvm.meta_schedule.tune_context import TuneContext +from tvm.script import tir as T +from tvm.target import Target + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.ir_module +class ParallelizeVectorizeUnroll: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + with T.block("root"): + T.reads([]) + T.writes([]) + T.block_attr({"meta_schedule.parallel": 128, "meta_schedule.vectorize": 16, "meta_schedule.unroll_explicit": 2}) + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def _create_context(mod, target, rule): + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_parallel_vectorize_unroll(): + expected = [ + [ + 'b0 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.parallel", ann_val=512)', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.vectorize", ann_val=32)', + "v1 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.unroll_explicit", ann_val=v1)', + ] + ] + mod = Matmul + target = Target("llvm --num-cores=32") + ctx = _create_context( + mod=mod, + target=target, + rule=parallel_vectorize_unroll(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_parallel_vectorize_unroll() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py new file mode 100644 index 000000000000..b4d1964d3775 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing.schedule_rule import random_compute_location +from tvm.meta_schedule.testing.space_generation import check_trace +from tvm.meta_schedule.tune_context import TuneContext +from tvm.script import tir as T +from tvm.target import Target + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks + +@tvm.script.ir_module +class Add: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [2048, 2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048, 2048], dtype="float32") + A_cached = T.alloc_buffer([2048, 2048, 2048], dtype="float32") + # body + for i, j, k in T.grid(2048, 2048, 2048): + with T.block("move"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + T.reads([A[vi, vj, vk]]) + T.writes([A_cached[vi, vj, vk]]) + A_cached[vi, vj, vk] = A[vi, vj, vk] + for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32): + with T.block("add"): + vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2) + vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + j2) + vk = T.axis.spatial(2048, k0 * 32 + k1) + T.reads([A_cached[vi, vj, vk]]) + T.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A_cached[vi, vj, vk] + T.float32(1) + +# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks +# fmt: on + + +def _create_context(mod, target, rule): + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_random_compute_location(): + expected = [ + [ + 'b0 = sch.get_block(name="move", func_name="main")', + "l1 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True)", + ] + ] + mod = Add + target = Target("llvm") + ctx = _create_context( + mod=mod, + target=target, + rule=random_compute_location(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_random_compute_location() diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index 9b3ddfd7c789..d0cdc3571bd7 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -16,25 +16,35 @@ # under the License. """ Test Meta Schedule SearchStrategy """ # pylint: disable=missing-function-docstring -from typing import List - import sys +from typing import List, Optional, Tuple, Union +import numpy as np import pytest - import tvm +from tvm.ir import IRModule from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.builder import LocalBuilder +from tvm.meta_schedule.cost_model import PyCostModel +from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload +from tvm.meta_schedule.mutator.mutator import PyMutator +from tvm.meta_schedule.runner import LocalRunner, RunnerResult +from tvm.meta_schedule.search_strategy import ( + EvolutionarySearch, + MeasureCandidate, + ReplayFunc, + ReplayTrace, + SearchStrategy, +) from tvm.meta_schedule.space_generator import ScheduleFn -from tvm.meta_schedule.search_strategy import ReplayTrace - +from tvm.meta_schedule.task_scheduler import RoundRobin from tvm.script import tir as T from tvm.tir.schedule import Schedule, Trace MATMUL_M = 32 -# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking +# pylint: disable=missing-class-docstring,invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, unbalanced-tuple-unpacking # fmt: off @tvm.script.ir_module @@ -53,48 +63,209 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] # fmt: on -# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument +# pylint: enable=missing-class-docstring,invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument -def _is_trace_equal(sch_1: Schedule, sch_2: Schedule) -> bool: - trace_1 = Trace(sch_1.trace.insts, {}) - trace_2 = Trace(sch_2.trace.insts, {}) +def _is_trace_equal(sch_1: Schedule, sch_2: Schedule, remove_decisions=True) -> bool: + if remove_decisions: + trace_1 = Trace(sch_1.trace.insts, {}) + trace_2 = Trace(sch_2.trace.insts, {}) + else: + trace_1 = sch_1.trace + trace_2 = sch_2.trace return str(trace_1) == str(trace_2) def _schedule_matmul(sch: Schedule): block = sch.get_block("matmul") i, j, k = sch.get_loops(block=block) - # TODO(@zxybazh): Change to `sample_perfect_tile` after upstreaming - i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) - j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) - k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + i_0, i_1, i_2, i_3 = sch.split(i, sch.sample_perfect_tile(i, n=4)) + j_0, j_1, j_2, j_3 = sch.split(j, sch.sample_perfect_tile(j, n=4)) + k_0, k_1 = sch.split(k, sch.sample_perfect_tile(k, n=2)) sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) -def test_meta_schedule_replay_trace(): +@pytest.mark.parametrize("TestClass", [ReplayFunc, ReplayTrace]) +def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disable = invalid-name num_trials_per_iter = 7 num_trials_total = 20 - (example_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) - replay = ReplayTrace(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) - tune_context = TuneContext(mod=Matmul) - replay.initialize_with_tune_context(tune_context) - - num_trials_each_round: List[int] = [] - replay.pre_tuning([example_sch]) - while True: - candidates = replay.generate_measure_candidates() - if candidates is None: - break - num_trials_each_round.append(len(candidates)) + strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) + tune_context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul)) + tune_context.space_generator.initialize_with_tune_context(tune_context) + spaces = tune_context.space_generator.generate_design_space(tune_context.mod) + + strategy.initialize_with_tune_context(tune_context) + strategy.pre_tuning(spaces) + (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) + num_trials_each_iter: List[int] = [] + candidates = strategy.generate_measure_candidates() + while candidates is not None: + num_trials_each_iter.append(len(candidates)) + runner_results: List[RunnerResult] = [] + for candidate in candidates: + _is_trace_equal( + candidate.sch, + correct_sch, + remove_decisions=(isinstance(strategy, ReplayTrace)), + ) + runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) + strategy.notify_runner_results(tune_context, candidates, runner_results) + candidates = strategy.generate_measure_candidates() + strategy.post_tuning() + assert num_trials_each_iter == [7, 7, 6] + + +def test_meta_schedule_evolutionary_search(): # pylint: disable = invalid-name + class DummyMutator(PyMutator): + """Dummy Mutator for testing""" + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, trace: Trace) -> Optional[Trace]: + return Trace(trace.insts, {}) + + class DummyDatabase(PyDatabase): + """Dummy Database for testing""" + + def __init__(self): + super().__init__() + self.records = [] + self.workload_reg = [] + + def has_workload(self, mod: IRModule) -> bool: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return True + return False + + def commit_tuning_record(self, record: TuningRecord) -> None: + self.records.append(record) + + def commit_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return workload + workload = Workload(mod) + self.workload_reg.append(workload) + return workload + + def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: + return list( + filter( + lambda x: x.workload == workload, + sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)), + ) + )[: int(top_k)] + + def __len__(self) -> int: + return len(self.records) + + def print_results(self) -> None: + print("\n".join([str(r) for r in self.records])) + + class RandomModel(PyCostModel): + """Random cost model for testing""" + + random_state: Union[Tuple[str, np.ndarray, int, int, float], dict] + path: Optional[str] + + def __init__( + self, + *, + seed: Optional[int] = None, + path: Optional[str] = None, + max_range: Optional[int] = 100, + ): + super().__init__() + if path is not None: + self.load(path) + else: + np.random.seed(seed) + self.random_state = np.random.get_state() + self.max_range = max_range + + def load(self, path: str) -> None: + self.random_state = tuple(np.load(path, allow_pickle=True)) + + def save(self, path: str) -> None: + np.save(path, np.array(self.random_state, dtype=object), allow_pickle=True) + + def update( + self, + tune_context: TuneContext, + candidates: List[MeasureCandidate], + results: List[RunnerResult], + ) -> None: + pass + + def predict( + self, tune_context: TuneContext, candidates: List[MeasureCandidate] + ) -> np.ndarray: + np.random.set_state(self.random_state) + result = np.random.rand(len(candidates)) * self.max_range + self.random_state = np.random.get_state() + return result + + num_trials_per_iter = 10 + num_trials_total = 100 + + strategy = EvolutionarySearch( + num_trials_per_iter=num_trials_per_iter, + num_trials_total=num_trials_total, + population_size=5, + init_measured_ratio=0.1, + init_min_unmeasured=50, + genetic_num_iters=3, + genetic_mutate_prob=0.5, + genetic_max_fail_count=10, + eps_greedy=0.9, + ) + tune_context = TuneContext( + mod=Matmul, + space_generator=ScheduleFn(sch_fn=_schedule_matmul), + mutator_probs={ + DummyMutator(): 1.0, + }, + target=tvm.target.Target("llvm"), + num_threads=1, # because we are using a mutator from the python side + ) + _scheduler = RoundRobin( + tasks=[tune_context], + builder=LocalBuilder(), + runner=LocalRunner(), + database=DummyDatabase(), + cost_model=RandomModel(), + measure_callbacks=[], + ) + tune_context.space_generator.initialize_with_tune_context(tune_context) + spaces = tune_context.space_generator.generate_design_space(tune_context.mod) + + strategy.initialize_with_tune_context(tune_context) + strategy.pre_tuning(spaces) + (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) + num_trials_each_iter: List[int] = [] + candidates = strategy.generate_measure_candidates() + while candidates is not None: + num_trials_each_iter.append(len(candidates)) runner_results: List[RunnerResult] = [] for candidate in candidates: - assert _is_trace_equal(candidate.sch, example_sch) - runner_results.append(RunnerResult(run_secs=[0.5, 0.4, 0.3], error_msg=None)) - replay.notify_runner_results(runner_results) - replay.post_tuning() - assert num_trials_each_round == [7, 7, 6] + _is_trace_equal( + candidate.sch, + correct_sch, + remove_decisions=(isinstance(strategy, ReplayTrace)), + ) + runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) + strategy.notify_runner_results(tune_context, candidates, runner_results) + candidates = strategy.generate_measure_candidates() + strategy.post_tuning() + print(num_trials_each_iter) + correct_count = 10 # For each iteration except the last one + assert num_trials_each_iter == [correct_count] * (num_trials_total // correct_count) + ( + [num_trials_total % correct_count] if num_trials_total % correct_count != 0 else [] + ) + del _scheduler if __name__ == "__main__": diff --git a/tests/python/unittest/test_meta_schedule_sketch_cpu.py b/tests/python/unittest/test_meta_schedule_sketch_cpu.py new file mode 100644 index 000000000000..d0b20a3dd104 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_sketch_cpu.py @@ -0,0 +1,795 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from typing import List + +from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.testing.space_generation import check_trace, create_context +from tvm.target import Target +from tvm.te import create_prim_func + + +def _target() -> Target: + return Target("llvm --num-cores=16") + + +def test_meta_schedule_cpu_sketch_matmul(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v25 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v25)', + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + 'b2 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l3, l4, l5 = sch.get_loops(block=b0)", + "v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9])", + "v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17])", + "v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l24, l25 = sch.split(loop=l5, factors=[v22, v23])", + "sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21)", + "sch.reverse_compute_at(block=b2, loop=l18, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26)', + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + 'b2 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "l3, l4, l5 = sch.get_loops(block=b0)", + "v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9])", + "v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17])", + "v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l24, l25 = sch.split(loop=l5, factors=[v22, v23])", + "sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21)", + "sch.reverse_compute_at(block=b2, loop=l19, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26)', + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_meta_schedule_cpu_sketch_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", + "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", + "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", + "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v25 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v25)', + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "b2, = sch.get_consumers(block=b0)", + "l3, l4, l5 = sch.get_loops(block=b0)", + "v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9])", + "v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17])", + "v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l24, l25 = sch.split(loop=l5, factors=[v22, v23])", + "sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21)", + "sch.reverse_compute_at(block=b2, loop=l18, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26)', + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "b2, = sch.get_consumers(block=b0)", + "l3, l4, l5 = sch.get_loops(block=b0)", + "v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9])", + "v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17])", + "v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l24, l25 = sch.split(loop=l5, factors=[v22, v23])", + "sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21)", + "sch.reverse_compute_at(block=b2, loop=l19, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26)', + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_meta_schedule_cpu_sketch_conv2d_nchw(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="compute", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l3, l4, l5, l6, l7, l8, l9 = sch.get_loops(block=b1)", + "v10, v11, v12, v13 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", + "l14, l15, l16, l17 = sch.split(loop=l3, factors=[v10, v11, v12, v13])", + "v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l22, l23, l24, l25 = sch.split(loop=l4, factors=[v18, v19, v20, v21])", + "v26, v27, v28, v29 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", + "l30, l31, l32, l33 = sch.split(loop=l5, factors=[v26, v27, v28, v29])", + "v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", + "l38, l39, l40, l41 = sch.split(loop=l6, factors=[v34, v35, v36, v37])", + "v42, v43 = sch.sample_perfect_tile(loop=l7, n=2, max_innermost_factor=64)", + "l44, l45 = sch.split(loop=l7, factors=[v42, v43])", + "v46, v47 = sch.sample_perfect_tile(loop=l8, n=2, max_innermost_factor=64)", + "l48, l49 = sch.split(loop=l8, factors=[v46, v47])", + "v50, v51 = sch.sample_perfect_tile(loop=l9, n=2, max_innermost_factor=64)", + "l52, l53 = sch.split(loop=l9, factors=[v50, v51])", + "sch.reorder(l14, l22, l30, l38, l15, l23, l31, l39, l44, l48, l52, l16, l24, l32, l40, l45, l49, l53, l17, l25, l33, l41)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.vectorize", ann_val=32)', + "v54 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v54)', + "l55 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l55, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="compute", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + 'b3 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="global")', + "l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b1)", + "v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l15, l16, l17, l18 = sch.split(loop=l4, factors=[v11, v12, v13, v14])", + "v19, v20, v21, v22 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", + "l23, l24, l25, l26 = sch.split(loop=l5, factors=[v19, v20, v21, v22])", + "v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", + "l31, l32, l33, l34 = sch.split(loop=l6, factors=[v27, v28, v29, v30])", + "v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l7, n=4, max_innermost_factor=64)", + "l39, l40, l41, l42 = sch.split(loop=l7, factors=[v35, v36, v37, v38])", + "v43, v44 = sch.sample_perfect_tile(loop=l8, n=2, max_innermost_factor=64)", + "l45, l46 = sch.split(loop=l8, factors=[v43, v44])", + "v47, v48 = sch.sample_perfect_tile(loop=l9, n=2, max_innermost_factor=64)", + "l49, l50 = sch.split(loop=l9, factors=[v47, v48])", + "v51, v52 = sch.sample_perfect_tile(loop=l10, n=2, max_innermost_factor=64)", + "l53, l54 = sch.split(loop=l10, factors=[v51, v52])", + "sch.reorder(l15, l23, l31, l39, l16, l24, l32, l40, l45, l49, l53, l17, l25, l33, l41, l46, l50, l54, l18, l26, l34, l42)", + "sch.reverse_compute_at(block=b3, loop=l39, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.vectorize", ann_val=32)', + "v55 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v55)', + "l56 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l56, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="compute", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + 'b3 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="global")', + "l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b1)", + "v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l4, n=4, max_innermost_factor=64)", + "l15, l16, l17, l18 = sch.split(loop=l4, factors=[v11, v12, v13, v14])", + "v19, v20, v21, v22 = sch.sample_perfect_tile(loop=l5, n=4, max_innermost_factor=64)", + "l23, l24, l25, l26 = sch.split(loop=l5, factors=[v19, v20, v21, v22])", + "v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", + "l31, l32, l33, l34 = sch.split(loop=l6, factors=[v27, v28, v29, v30])", + "v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l7, n=4, max_innermost_factor=64)", + "l39, l40, l41, l42 = sch.split(loop=l7, factors=[v35, v36, v37, v38])", + "v43, v44 = sch.sample_perfect_tile(loop=l8, n=2, max_innermost_factor=64)", + "l45, l46 = sch.split(loop=l8, factors=[v43, v44])", + "v47, v48 = sch.sample_perfect_tile(loop=l9, n=2, max_innermost_factor=64)", + "l49, l50 = sch.split(loop=l9, factors=[v47, v48])", + "v51, v52 = sch.sample_perfect_tile(loop=l10, n=2, max_innermost_factor=64)", + "l53, l54 = sch.split(loop=l10, factors=[v51, v52])", + "sch.reorder(l15, l23, l31, l39, l16, l24, l32, l40, l45, l49, l53, l17, l25, l33, l41, l46, l50, l54, l18, l26, l34, l42)", + "sch.reverse_compute_at(block=b3, loop=l40, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.vectorize", ann_val=32)', + "v55 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v55)', + "l56 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l56, preserve_unit_loops=True)", + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.conv2d_nchw( + n=1, + h=56, + w=56, + ci=512, + co=512, + kh=3, + kw=3, + stride=1, + padding=1, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable=invalid-name + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="compute", func_name="main")', + 'b2 = sch.get_block(name="bias_add", func_name="main")', + 'b3 = sch.get_block(name="bn_mul", func_name="main")', + 'b4 = sch.get_block(name="bn_add", func_name="main")', + 'b5 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b4)", + "sch.compute_inline(block=b3)", + "sch.compute_inline(block=b2)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l6, l7, l8, l9, l10, l11, l12 = sch.get_loops(block=b1)", + "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l6, n=4, max_innermost_factor=64)", + "l17, l18, l19, l20 = sch.split(loop=l6, factors=[v13, v14, v15, v16])", + "v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l7, n=4, max_innermost_factor=64)", + "l25, l26, l27, l28 = sch.split(loop=l7, factors=[v21, v22, v23, v24])", + "v29, v30, v31, v32 = sch.sample_perfect_tile(loop=l8, n=4, max_innermost_factor=64)", + "l33, l34, l35, l36 = sch.split(loop=l8, factors=[v29, v30, v31, v32])", + "v37, v38, v39, v40 = sch.sample_perfect_tile(loop=l9, n=4, max_innermost_factor=64)", + "l41, l42, l43, l44 = sch.split(loop=l9, factors=[v37, v38, v39, v40])", + "v45, v46 = sch.sample_perfect_tile(loop=l10, n=2, max_innermost_factor=64)", + "l47, l48 = sch.split(loop=l10, factors=[v45, v46])", + "v49, v50 = sch.sample_perfect_tile(loop=l11, n=2, max_innermost_factor=64)", + "l51, l52 = sch.split(loop=l11, factors=[v49, v50])", + "v53, v54 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64)", + "l55, l56 = sch.split(loop=l12, factors=[v53, v54])", + "sch.reorder(l17, l25, l33, l41, l18, l26, l34, l42, l47, l51, l55, l19, l27, l35, l43, l48, l52, l56, l20, l28, l36, l44)", + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.vectorize", ann_val=32)', + "v57 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v57)', + "l58 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l58, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="compute", func_name="main")', + 'b2 = sch.get_block(name="bias_add", func_name="main")', + 'b3 = sch.get_block(name="bn_mul", func_name="main")', + 'b4 = sch.get_block(name="bn_add", func_name="main")', + 'b5 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b4)", + "sch.compute_inline(block=b3)", + "sch.compute_inline(block=b2)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "b6, = sch.get_consumers(block=b1)", + "l7, l8, l9, l10, l11, l12, l13 = sch.get_loops(block=b1)", + "v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l7, n=4, max_innermost_factor=64)", + "l18, l19, l20, l21 = sch.split(loop=l7, factors=[v14, v15, v16, v17])", + "v22, v23, v24, v25 = sch.sample_perfect_tile(loop=l8, n=4, max_innermost_factor=64)", + "l26, l27, l28, l29 = sch.split(loop=l8, factors=[v22, v23, v24, v25])", + "v30, v31, v32, v33 = sch.sample_perfect_tile(loop=l9, n=4, max_innermost_factor=64)", + "l34, l35, l36, l37 = sch.split(loop=l9, factors=[v30, v31, v32, v33])", + "v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l10, n=4, max_innermost_factor=64)", + "l42, l43, l44, l45 = sch.split(loop=l10, factors=[v38, v39, v40, v41])", + "v46, v47 = sch.sample_perfect_tile(loop=l11, n=2, max_innermost_factor=64)", + "l48, l49 = sch.split(loop=l11, factors=[v46, v47])", + "v50, v51 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64)", + "l52, l53 = sch.split(loop=l12, factors=[v50, v51])", + "v54, v55 = sch.sample_perfect_tile(loop=l13, n=2, max_innermost_factor=64)", + "l56, l57 = sch.split(loop=l13, factors=[v54, v55])", + "sch.reorder(l18, l26, l34, l42, l19, l27, l35, l43, l48, l52, l56, l20, l28, l36, l44, l49, l53, l57, l21, l29, l37, l45)", + "sch.reverse_compute_at(block=b6, loop=l42, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.vectorize", ann_val=32)', + "v58 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v58)', + "l59 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l59, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="compute", func_name="main")', + 'b2 = sch.get_block(name="bias_add", func_name="main")', + 'b3 = sch.get_block(name="bn_mul", func_name="main")', + 'b4 = sch.get_block(name="bn_add", func_name="main")', + 'b5 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b4)", + "sch.compute_inline(block=b3)", + "sch.compute_inline(block=b2)", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "b6, = sch.get_consumers(block=b1)", + "l7, l8, l9, l10, l11, l12, l13 = sch.get_loops(block=b1)", + "v14, v15, v16, v17 = sch.sample_perfect_tile(loop=l7, n=4, max_innermost_factor=64)", + "l18, l19, l20, l21 = sch.split(loop=l7, factors=[v14, v15, v16, v17])", + "v22, v23, v24, v25 = sch.sample_perfect_tile(loop=l8, n=4, max_innermost_factor=64)", + "l26, l27, l28, l29 = sch.split(loop=l8, factors=[v22, v23, v24, v25])", + "v30, v31, v32, v33 = sch.sample_perfect_tile(loop=l9, n=4, max_innermost_factor=64)", + "l34, l35, l36, l37 = sch.split(loop=l9, factors=[v30, v31, v32, v33])", + "v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l10, n=4, max_innermost_factor=64)", + "l42, l43, l44, l45 = sch.split(loop=l10, factors=[v38, v39, v40, v41])", + "v46, v47 = sch.sample_perfect_tile(loop=l11, n=2, max_innermost_factor=64)", + "l48, l49 = sch.split(loop=l11, factors=[v46, v47])", + "v50, v51 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64)", + "l52, l53 = sch.split(loop=l12, factors=[v50, v51])", + "v54, v55 = sch.sample_perfect_tile(loop=l13, n=2, max_innermost_factor=64)", + "l56, l57 = sch.split(loop=l13, factors=[v54, v55])", + "sch.reorder(l18, l26, l34, l42, l19, l27, l35, l43, l48, l52, l56, l20, l28, l36, l44, l49, l53, l57, l21, l29, l37, l45)", + "sch.reverse_compute_at(block=b6, loop=l43, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.vectorize", ann_val=32)', + "v58 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b5, ann_key="meta_schedule.unroll_explicit", ann_val=v58)', + "l59 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l59, preserve_unit_loops=True)", + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.conv2d_nchw_bias_bn_relu( + n=1, + h=56, + w=56, + ci=512, + co=512, + kh=3, + kw=3, + stride=1, + padding=1, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_meta_schedule_sketch_cpu_max_pool2d_nchw(): # pylint: disable=invalid-name + # pylint: disable=line-too-long + expected: List[List[str]] = [ + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v2 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v2)', + "l3 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l3, preserve_unit_loops=True)", + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.max_pool2d_nchw( + n=1, + h=56, + w=56, + ci=512, + padding=1, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_meta_schedule_cpu_sketch_batchnorm(): # pylint: disable=invalid-name + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "l5 = sch.fuse(l3, l4)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l8, factor_axis=1)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v11 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v11)', + "b12, = sch.get_producers(block=b0)", + 'sch.unannotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer")', + "l13 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l13, preserve_unit_loops=True)", + "l14 = sch.sample_compute_location(block=b12)", + "sch.compute_at(block=b12, loop=l14, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + "l2, l3, l4 = sch.get_loops(block=b0)", + "l5 = sch.fuse(l3, l4)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l9, factor_axis=1)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v11 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v11)', + "b12, = sch.get_producers(block=b0)", + 'sch.unannotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer")', + "l13 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l13, preserve_unit_loops=True)", + "l14 = sch.sample_compute_location(block=b12)", + "sch.compute_at(block=b12, loop=l14, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=32)', + "v2 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v2)', + "l3 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l3, preserve_unit_loops=True)", + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func(te_workload.norm_bmn(B=1, M=256, N=256)), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_meta_schedule_cpu_sketch_softmax(): # pylint: disable=invalid-name + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "l4, l5 = sch.get_loops(block=b2)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l8, factor_axis=1)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + "l11, l12 = sch.get_loops(block=b0)", + "v13, v14 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64)", + "l15, l16 = sch.split(loop=l12, factors=[v13, v14])", + "b17 = sch.rfactor(loop=l15, factor_axis=1)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v18 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v18)', + "b19, = sch.get_producers(block=b2)", + 'sch.unannotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer")', + "l20 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l20, preserve_unit_loops=True)", + "l21 = sch.sample_compute_location(block=b19)", + "sch.compute_at(block=b19, loop=l21, preserve_unit_loops=True)", + "l22 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l22, preserve_unit_loops=True)", + "b23, = sch.get_producers(block=b0)", + 'sch.unannotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer")', + "l24 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l24, preserve_unit_loops=True)", + "l25 = sch.sample_compute_location(block=b23)", + "sch.compute_at(block=b23, loop=l25, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "l4, l5 = sch.get_loops(block=b2)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l8, factor_axis=1)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + "l11, l12 = sch.get_loops(block=b0)", + "v13, v14 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64)", + "l15, l16 = sch.split(loop=l12, factors=[v13, v14])", + "b17 = sch.rfactor(loop=l16, factor_axis=1)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v18 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v18)', + "b19, = sch.get_producers(block=b2)", + 'sch.unannotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer")', + "l20 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l20, preserve_unit_loops=True)", + "l21 = sch.sample_compute_location(block=b19)", + "sch.compute_at(block=b19, loop=l21, preserve_unit_loops=True)", + "l22 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l22, preserve_unit_loops=True)", + "b23, = sch.get_producers(block=b0)", + 'sch.unannotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer")', + "l24 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l24, preserve_unit_loops=True)", + "l25 = sch.sample_compute_location(block=b23)", + "sch.compute_at(block=b23, loop=l25, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "l4, l5 = sch.get_loops(block=b2)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l8, factor_axis=1)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v11 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v11)', + "b12, = sch.get_producers(block=b2)", + 'sch.unannotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer")', + "l13 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l13, preserve_unit_loops=True)", + "l14 = sch.sample_compute_location(block=b12)", + "sch.compute_at(block=b12, loop=l14, preserve_unit_loops=True)", + "l15 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l15, preserve_unit_loops=True)", + "l16 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l16, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "l4, l5 = sch.get_loops(block=b2)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l9, factor_axis=1)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + "l11, l12 = sch.get_loops(block=b0)", + "v13, v14 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64)", + "l15, l16 = sch.split(loop=l12, factors=[v13, v14])", + "b17 = sch.rfactor(loop=l15, factor_axis=1)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v18 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v18)', + "b19, = sch.get_producers(block=b2)", + 'sch.unannotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer")', + "l20 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l20, preserve_unit_loops=True)", + "l21 = sch.sample_compute_location(block=b19)", + "sch.compute_at(block=b19, loop=l21, preserve_unit_loops=True)", + "l22 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l22, preserve_unit_loops=True)", + "b23, = sch.get_producers(block=b0)", + 'sch.unannotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer")', + "l24 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l24, preserve_unit_loops=True)", + "l25 = sch.sample_compute_location(block=b23)", + "sch.compute_at(block=b23, loop=l25, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "l4, l5 = sch.get_loops(block=b2)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l9, factor_axis=1)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + "l11, l12 = sch.get_loops(block=b0)", + "v13, v14 = sch.sample_perfect_tile(loop=l12, n=2, max_innermost_factor=64)", + "l15, l16 = sch.split(loop=l12, factors=[v13, v14])", + "b17 = sch.rfactor(loop=l16, factor_axis=1)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v18 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v18)', + "b19, = sch.get_producers(block=b2)", + 'sch.unannotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer")', + "l20 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l20, preserve_unit_loops=True)", + "l21 = sch.sample_compute_location(block=b19)", + "sch.compute_at(block=b19, loop=l21, preserve_unit_loops=True)", + "l22 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l22, preserve_unit_loops=True)", + "b23, = sch.get_producers(block=b0)", + 'sch.unannotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer")', + "l24 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l24, preserve_unit_loops=True)", + "l25 = sch.sample_compute_location(block=b23)", + "sch.compute_at(block=b23, loop=l25, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "l4, l5 = sch.get_loops(block=b2)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l9, factor_axis=1)", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v11 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v11)', + "b12, = sch.get_producers(block=b2)", + 'sch.unannotate(block_or_loop=b2, ann_key="meta_schedule.random_compute_producer")', + "l13 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l13, preserve_unit_loops=True)", + "l14 = sch.sample_compute_location(block=b12)", + "sch.compute_at(block=b12, loop=l14, preserve_unit_loops=True)", + "l15 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l15, preserve_unit_loops=True)", + "l16 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l16, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "l4, l5 = sch.get_loops(block=b0)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l8, factor_axis=1)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v11 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v11)', + "l12 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l12, preserve_unit_loops=True)", + "l13 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l13, preserve_unit_loops=True)", + "b14, = sch.get_producers(block=b0)", + 'sch.unannotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer")', + "l15 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True)", + "l16 = sch.sample_compute_location(block=b14)", + "sch.compute_at(block=b14, loop=l16, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "l4, l5 = sch.get_loops(block=b0)", + "v6, v7 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64)", + "l8, l9 = sch.split(loop=l5, factors=[v6, v7])", + "b10 = sch.rfactor(loop=l9, factor_axis=1)", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v11 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v11)', + "l12 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l12, preserve_unit_loops=True)", + "l13 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l13, preserve_unit_loops=True)", + "b14, = sch.get_producers(block=b0)", + 'sch.unannotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer")', + "l15 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True)", + "l16 = sch.sample_compute_location(block=b14)", + "sch.compute_at(block=b14, loop=l16, preserve_unit_loops=True)", + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.parallel", ann_val=256)', + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.vectorize", ann_val=32)', + "v4 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v4)', + "l5 = sch.sample_compute_location(block=b2)", + "sch.compute_at(block=b2, loop=l5, preserve_unit_loops=True)", + "l6 = sch.sample_compute_location(block=b1)", + "sch.compute_at(block=b1, loop=l6, preserve_unit_loops=True)", + "l7 = sch.sample_compute_location(block=b0)", + "sch.compute_at(block=b0, loop=l7, preserve_unit_loops=True)", + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func(te_workload.softmax_mn(m=256, n=256)), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 9 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_meta_schedule_cpu_sketch_matmul() + test_meta_schedule_cpu_sketch_matmul_relu() + test_meta_schedule_cpu_sketch_conv2d_nchw() + test_meta_schedule_cpu_sketch_conv2d_nchw_bias_bn_relu() + test_meta_schedule_sketch_cpu_max_pool2d_nchw() + test_meta_schedule_cpu_sketch_batchnorm() + test_meta_schedule_cpu_sketch_softmax() diff --git a/tests/python/unittest/test_meta_schedule_sketch_cuda.py b/tests/python/unittest/test_meta_schedule_sketch_cuda.py new file mode 100644 index 000000000000..3255c958a575 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_sketch_cuda.py @@ -0,0 +1,426 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.testing.space_generation import check_trace, create_context +from tvm.target import Target +from tvm.te import create_prim_func + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def _target_with_max_threads_per_block() -> Target: + return Target("nvidia/geforce-rtx-3080") + + +def test_meta_schedule_cuda_sketch_matmul(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + 'b2 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l3, l4, l5 = sch.get_loops(block=b0)", + "v6, v7, v8, v9, v10 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)", + "l11, l12, l13, l14, l15 = sch.split(loop=l3, factors=[v6, v7, v8, v9, v10])", + "v16, v17, v18, v19, v20 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64)", + "l21, l22, l23, l24, l25 = sch.split(loop=l4, factors=[v16, v17, v18, v19, v20])", + "v26, v27, v28 = sch.sample_perfect_tile(loop=l5, n=3, max_innermost_factor=64)", + "l29, l30, l31 = sch.split(loop=l5, factors=[v26, v27, v28])", + "sch.reorder(l11, l21, l12, l22, l13, l23, l29, l30, l14, l24, l31, l15, l25)", + "l32 = sch.fuse(l11, l21)", + 'sch.bind(loop=l32, thread_axis="blockIdx.x")', + "l33 = sch.fuse(l12, l22)", + 'sch.bind(loop=l33, thread_axis="vthread.x")', + "l34 = sch.fuse(l13, l23)", + 'sch.bind(loop=l34, thread_axis="threadIdx.x")', + 'b35 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b35, loop=l29, preserve_unit_loops=True)", + "l36, l37, l38, l39, l40, l41 = sch.get_loops(block=b35)", + "l42 = sch.fuse(l40, l41)", + "v43 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b35, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)', + 'b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b44, loop=l29, preserve_unit_loops=True)", + "l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44)", + "l51 = sch.fuse(l49, l50)", + "v52 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v52)', + "sch.reverse_compute_at(block=b2, loop=l34, preserve_unit_loops=True)", + "v53 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v53)', + ] + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_meta_schedule_cuda_sketch_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="compute", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + 'b3 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "l4, l5, l6 = sch.get_loops(block=b0)", + "v7, v8, v9, v10, v11 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64)", + "l12, l13, l14, l15, l16 = sch.split(loop=l4, factors=[v7, v8, v9, v10, v11])", + "v17, v18, v19, v20, v21 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64)", + "l22, l23, l24, l25, l26 = sch.split(loop=l5, factors=[v17, v18, v19, v20, v21])", + "v27, v28, v29 = sch.sample_perfect_tile(loop=l6, n=3, max_innermost_factor=64)", + "l30, l31, l32 = sch.split(loop=l6, factors=[v27, v28, v29])", + "sch.reorder(l12, l22, l13, l23, l14, l24, l30, l31, l15, l25, l32, l16, l26)", + "l33 = sch.fuse(l12, l22)", + 'sch.bind(loop=l33, thread_axis="blockIdx.x")', + "l34 = sch.fuse(l13, l23)", + 'sch.bind(loop=l34, thread_axis="vthread.x")', + "l35 = sch.fuse(l14, l24)", + 'sch.bind(loop=l35, thread_axis="threadIdx.x")', + 'b36 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b36, loop=l30, preserve_unit_loops=True)", + "l37, l38, l39, l40, l41, l42 = sch.get_loops(block=b36)", + "l43 = sch.fuse(l41, l42)", + "v44 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b36, ann_key="meta_schedule.cooperative_fetch", ann_val=v44)', + 'b45 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b45, loop=l30, preserve_unit_loops=True)", + "l46, l47, l48, l49, l50, l51 = sch.get_loops(block=b45)", + "l52 = sch.fuse(l50, l51)", + "v53 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b45, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)', + "sch.reverse_compute_at(block=b3, loop=l35, preserve_unit_loops=True)", + "sch.reverse_compute_inline(block=b1)", + "v54 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v54)', + ] + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_meta_schedule_cuda_sketch_conv2d_nchw(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="compute", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + 'b3 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local")', + "l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b1)", + "v11, v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l4, n=5, max_innermost_factor=64)", + "l16, l17, l18, l19, l20 = sch.split(loop=l4, factors=[v11, v12, v13, v14, v15])", + "v21, v22, v23, v24, v25 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64)", + "l26, l27, l28, l29, l30 = sch.split(loop=l5, factors=[v21, v22, v23, v24, v25])", + "v31, v32, v33, v34, v35 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64)", + "l36, l37, l38, l39, l40 = sch.split(loop=l6, factors=[v31, v32, v33, v34, v35])", + "v41, v42, v43, v44, v45 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64)", + "l46, l47, l48, l49, l50 = sch.split(loop=l7, factors=[v41, v42, v43, v44, v45])", + "v51, v52, v53 = sch.sample_perfect_tile(loop=l8, n=3, max_innermost_factor=64)", + "l54, l55, l56 = sch.split(loop=l8, factors=[v51, v52, v53])", + "v57, v58, v59 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64)", + "l60, l61, l62 = sch.split(loop=l9, factors=[v57, v58, v59])", + "v63, v64, v65 = sch.sample_perfect_tile(loop=l10, n=3, max_innermost_factor=64)", + "l66, l67, l68 = sch.split(loop=l10, factors=[v63, v64, v65])", + "sch.reorder(l16, l26, l36, l46, l17, l27, l37, l47, l18, l28, l38, l48, l54, l60, l66, l55, l61, l67, l19, l29, l39, l49, l56, l62, l68, l20, l30, l40, l50)", + "l69 = sch.fuse(l16, l26, l36, l46)", + 'sch.bind(loop=l69, thread_axis="blockIdx.x")', + "l70 = sch.fuse(l17, l27, l37, l47)", + 'sch.bind(loop=l70, thread_axis="vthread.x")', + "l71 = sch.fuse(l18, l28, l38, l48)", + 'sch.bind(loop=l71, thread_axis="threadIdx.x")', + 'b72 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b72, loop=l66, preserve_unit_loops=True)", + "l73, l74, l75, l76, l77, l78, l79, l80, l81, l82 = sch.get_loops(block=b72)", + "l83 = sch.fuse(l79, l80, l81, l82)", + "v84 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch", ann_val=v84)', + 'b85 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b85, loop=l66, preserve_unit_loops=True)", + "l86, l87, l88, l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b85)", + "l96 = sch.fuse(l92, l93, l94, l95)", + "v97 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch", ann_val=v97)', + "sch.reverse_compute_at(block=b3, loop=l71, preserve_unit_loops=True)", + "sch.compute_inline(block=b0)", + "v98 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v98)', + ] + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.conv2d_nchw( + n=1, + h=56, + w=56, + ci=512, + co=512, + kh=3, + kw=3, + stride=1, + padding=1, + ) + ), + target=_target(), + ) + + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disable=invalid-name + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="pad_temp", func_name="main")', + 'b1 = sch.get_block(name="compute", func_name="main")', + 'b2 = sch.get_block(name="bias_add", func_name="main")', + 'b3 = sch.get_block(name="bn_mul", func_name="main")', + 'b4 = sch.get_block(name="bn_add", func_name="main")', + 'b5 = sch.get_block(name="compute_1", func_name="main")', + 'b6 = sch.get_block(name="root", func_name="main")', + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + 'b7 = sch.cache_write(block=b1, write_buffer_index=0, storage_scope="local")', + "l8, l9, l10, l11, l12, l13, l14 = sch.get_loops(block=b1)", + "v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64)", + "l20, l21, l22, l23, l24 = sch.split(loop=l8, factors=[v15, v16, v17, v18, v19])", + "v25, v26, v27, v28, v29 = sch.sample_perfect_tile(loop=l9, n=5, max_innermost_factor=64)", + "l30, l31, l32, l33, l34 = sch.split(loop=l9, factors=[v25, v26, v27, v28, v29])", + "v35, v36, v37, v38, v39 = sch.sample_perfect_tile(loop=l10, n=5, max_innermost_factor=64)", + "l40, l41, l42, l43, l44 = sch.split(loop=l10, factors=[v35, v36, v37, v38, v39])", + "v45, v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l11, n=5, max_innermost_factor=64)", + "l50, l51, l52, l53, l54 = sch.split(loop=l11, factors=[v45, v46, v47, v48, v49])", + "v55, v56, v57 = sch.sample_perfect_tile(loop=l12, n=3, max_innermost_factor=64)", + "l58, l59, l60 = sch.split(loop=l12, factors=[v55, v56, v57])", + "v61, v62, v63 = sch.sample_perfect_tile(loop=l13, n=3, max_innermost_factor=64)", + "l64, l65, l66 = sch.split(loop=l13, factors=[v61, v62, v63])", + "v67, v68, v69 = sch.sample_perfect_tile(loop=l14, n=3, max_innermost_factor=64)", + "l70, l71, l72 = sch.split(loop=l14, factors=[v67, v68, v69])", + "sch.reorder(l20, l30, l40, l50, l21, l31, l41, l51, l22, l32, l42, l52, l58, l64, l70, l59, l65, l71, l23, l33, l43, l53, l60, l66, l72, l24, l34, l44, l54)", + "l73 = sch.fuse(l20, l30, l40, l50)", + 'sch.bind(loop=l73, thread_axis="blockIdx.x")', + "l74 = sch.fuse(l21, l31, l41, l51)", + 'sch.bind(loop=l74, thread_axis="vthread.x")', + "l75 = sch.fuse(l22, l32, l42, l52)", + 'sch.bind(loop=l75, thread_axis="threadIdx.x")', + 'b76 = sch.cache_read(block=b1, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b76, loop=l70, preserve_unit_loops=True)", + "l77, l78, l79, l80, l81, l82, l83, l84, l85, l86 = sch.get_loops(block=b76)", + "l87 = sch.fuse(l83, l84, l85, l86)", + "v88 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b76, ann_key="meta_schedule.cooperative_fetch", ann_val=v88)', + 'b89 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b89, loop=l70, preserve_unit_loops=True)", + "l90, l91, l92, l93, l94, l95, l96, l97, l98, l99 = sch.get_loops(block=b89)", + "l100 = sch.fuse(l96, l97, l98, l99)", + "v101 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b89, ann_key="meta_schedule.cooperative_fetch", ann_val=v101)', + "sch.reverse_compute_at(block=b7, loop=l75, preserve_unit_loops=True)", + "sch.reverse_compute_inline(block=b5)", + "sch.reverse_compute_inline(block=b4)", + "sch.reverse_compute_inline(block=b3)", + "sch.reverse_compute_inline(block=b2)", + "sch.compute_inline(block=b0)", + "v102 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b6, ann_key="meta_schedule.unroll_explicit", ann_val=v102)', + ] + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.conv2d_nchw_bias_bn_relu( + n=1, + h=56, + w=56, + ci=512, + co=512, + kh=3, + kw=3, + stride=1, + padding=1, + ) + ), + target=_target(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_meta_schedule_cuda_sketch_batchnorm(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + "b2, = sch.get_consumers(block=b0)", + "l3, = sch.get_loops(block=b2)", + "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l5, l6 = sch.split(loop=l3, factors=[None, v4])", + 'sch.bind(loop=l6, thread_axis="threadIdx.x")', + "sch.compute_at(block=b0, loop=l5, preserve_unit_loops=True)", + 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', + "l7, l8, l9, l10 = sch.get_loops(block=b0)", + "l11 = sch.fuse(l9, l10)", + "l12, l13 = sch.split(loop=l11, factors=[None, v4])", + 'sch.bind(loop=l13, thread_axis="threadIdx.x")', + "v14 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v14)', + ], + [ + 'b0 = sch.get_block(name="root", func_name="main")', + "v1 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.unroll_explicit", ann_val=v1)', + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.norm_bmn( + B=1, + M=256, + N=256, + ) + ), + target=_target_with_max_threads_per_block(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 2 + check_trace(spaces, expected) + + +def test_meta_schedule_cuda_sketch_softmax(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b3 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b1)", + "b4, = sch.get_consumers(block=b2)", + "l5, l6 = sch.get_loops(block=b4)", + "v7 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l8, l9 = sch.split(loop=l6, factors=[None, v7])", + 'sch.bind(loop=l9, thread_axis="threadIdx.x")', + "sch.compute_at(block=b2, loop=l5, preserve_unit_loops=True)", + 'sch.set_scope(block=b2, buffer_index=0, storage_scope="shared")', + "l10, l11, l12 = sch.get_loops(block=b2)", + "l13, l14 = sch.split(loop=l12, factors=[None, v7])", + 'sch.bind(loop=l14, thread_axis="threadIdx.x")', + "b15, b16 = sch.get_consumers(block=b0)", + "l17, l18, l19, l20 = sch.get_loops(block=b15)", + "sch.compute_at(block=b0, loop=l17, preserve_unit_loops=True)", + 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', + "l21, l22, l23 = sch.get_loops(block=b0)", + "l24, l25 = sch.split(loop=l23, factors=[None, v7])", + 'sch.bind(loop=l25, thread_axis="threadIdx.x")', + "v26 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v26)', + ], + [ + 'b0 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_expsum", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b0)", + "b3, = sch.get_consumers(block=b1)", + "l4, l5 = sch.get_loops(block=b3)", + "v6 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l7, l8 = sch.split(loop=l5, factors=[None, v6])", + 'sch.bind(loop=l8, thread_axis="threadIdx.x")', + "sch.compute_at(block=b1, loop=l4, preserve_unit_loops=True)", + 'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")', + "l9, l10, l11 = sch.get_loops(block=b1)", + "l12, l13 = sch.split(loop=l11, factors=[None, v6])", + 'sch.bind(loop=l13, thread_axis="threadIdx.x")', + "v14 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v14)', + ], + [ + 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', + 'b1 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b2 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b1)", + "v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", + "l4, l5 = sch.get_loops(block=b0)", + "l6, l7 = sch.split(loop=l5, factors=[None, v3])", + 'sch.bind(loop=l7, thread_axis="threadIdx.x")', + "v8 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v8)', + ], + [ + 'b0 = sch.get_block(name="T_softmax_exp", func_name="main")', + 'b1 = sch.get_block(name="root", func_name="main")', + "sch.compute_inline(block=b0)", + "v2 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v2)', + ], + ] + # pylint: enable=line-too-long + ctx = create_context( + create_prim_func( + te_workload.softmax_mn( + m=256, + n=256, + ) + ), + target=_target_with_max_threads_per_block(), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 4 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_meta_schedule_cuda_sketch_matmul() + test_meta_schedule_cuda_sketch_matmul_relu() + test_meta_schedule_cuda_sketch_conv2d_nchw() + test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu() + test_meta_schedule_cuda_sketch_batchnorm() + test_meta_schedule_cuda_sketch_softmax() diff --git a/tests/python/unittest/test_meta_schedule_space_generator.py b/tests/python/unittest/test_meta_schedule_space_generator.py index 49a3f6309183..3eb050db3baa 100644 --- a/tests/python/unittest/test_meta_schedule_space_generator.py +++ b/tests/python/unittest/test_meta_schedule_space_generator.py @@ -23,6 +23,9 @@ import pytest import tvm +from tvm._ffi.base import TVMError +from tvm.ir.module import IRModule +from tvm.meta_schedule.space_generator.space_generator import PySpaceGenerator from tvm.script import tir as T from tvm.tir.schedule import Schedule from tvm.meta_schedule.space_generator import ScheduleFn, PySpaceGenerator, SpaceGeneratorUnion diff --git a/tests/python/unittest/test_meta_schedule_task_extraction.py b/tests/python/unittest/test_meta_schedule_task_extraction.py new file mode 100644 index 000000000000..8523275f5186 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_task_extraction.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import sys +from typing import Tuple + +import pytest + +import tvm +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing import MODEL_TYPE, MODEL_TYPES, get_torch_model + + +@pytest.mark.skip("Skip because it runs too slowly as a unittest") +@pytest.mark.parametrize( + "model_name", + [ + # Image classification + "resnet50", + "alexnet", + "vgg16", + "squeezenet1_0", + "densenet121", + "densenet161", + "densenet169", + "densenet201", + "inception_v3", + "googlenet", + "shufflenet_v2_x1_0", + "mobilenet_v2", + "mobilenet_v3_large", + "mobilenet_v3_small", + "resnext50_32x4d", + "wide_resnet50_2", + "mnasnet1_0", + # Segmentation + "fcn_resnet50", + "fcn_resnet101", + "deeplabv3_resnet50", + "deeplabv3_resnet101", + "deeplabv3_mobilenet_v3_large", + "lraspp_mobilenet_v3_large", + # Object detection + "fasterrcnn_resnet50_fpn", + "fasterrcnn_mobilenet_v3_large_fpn", + "fasterrcnn_mobilenet_v3_large_320_fpn", + "maskrcnn_resnet50_fpn", + # video classification + "r3d_18", + "mc3_18", + "r2plus1d_18", + ], +) +@pytest.mark.parametrize("batch_size", [1, 8, 16]) +@pytest.mark.parametrize("target", ["llvm", "cuda"]) +def test_meta_schedule_extract_from_torch_model(model_name: str, batch_size: int, target: str): + if model_name == "inception_v3" and batch_size == 1: + pytest.skip("inception_v3 does not handle batch_size of 1") + + input_shape: Tuple[int, ...] + if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION: + input_shape = (batch_size, 3, 299, 299) + elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: + input_shape = (batch_size, 3, 299, 299) + elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: + input_shape = (1, 3, 300, 300) + elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION: + input_shape = (batch_size, 3, 3, 299, 299) + else: + raise ValueError("Unsupported model: " + model_name) + + output_shape: Tuple[int, int] = (batch_size, 1000) + mod, params = get_torch_model( + model_name=model_name, + input_shape=input_shape, + output_shape=output_shape, + dtype="float32", + ) + target = tvm.target.Target(target) + ms.integration.extract_task_from_relay(mod, params=params, target=target) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py index 7eb61ad2c7cf..d3c4dbca826f 100644 --- a/tests/python/unittest/test_meta_schedule_task_scheduler.py +++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py @@ -16,24 +16,22 @@ # under the License. """ Test Meta Schedule Task Scheduler """ -from typing import List - -import sys import random +import sys +from typing import List import pytest - import tvm -from tvm.script import tir as T from tvm.ir import IRModule -from tvm.tir import Schedule -from tvm.meta_schedule import TuneContext -from tvm.meta_schedule.space_generator import ScheduleFn -from tvm.meta_schedule.search_strategy import ReplayTrace -from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult -from tvm.meta_schedule.runner import PyRunner, RunnerInput, RunnerFuture, RunnerResult +from tvm.meta_schedule import TuneContext, measure_callback +from tvm.meta_schedule.builder import BuilderInput, BuilderResult, PyBuilder from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload -from tvm.meta_schedule.task_scheduler import RoundRobin, PyTaskScheduler +from tvm.meta_schedule.runner import PyRunner, RunnerFuture, RunnerInput, RunnerResult +from tvm.meta_schedule.search_strategy import ReplayTrace +from tvm.meta_schedule.space_generator import ScheduleFn +from tvm.meta_schedule.task_scheduler import PyTaskScheduler, RoundRobin +from tvm.script import tir as T +from tvm.tir import Schedule # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @@ -140,7 +138,10 @@ def __init__(self): self.records = [] self.workload_reg = [] - def has_workload(self, mod: IRModule) -> bool: + def has_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return True return False def commit_tuning_record(self, record: TuningRecord) -> None: @@ -183,7 +184,13 @@ def test_meta_schedule_task_scheduler_single(): rand_state=42, ) database = DummyDatabase() - round_robin = RoundRobin([task], DummyBuilder(), DummyRunner(), database) + round_robin = RoundRobin( + [task], + DummyBuilder(), + DummyRunner(), + database, + measure_callbacks=[measure_callback.AddToDatabase()], + ) round_robin.tune() assert len(database) == num_trials_total @@ -218,15 +225,29 @@ def test_meta_schedule_task_scheduler_multiple(): ), ] database = DummyDatabase() - round_robin = RoundRobin(tasks, DummyBuilder(), DummyRunner(), database) + round_robin = RoundRobin( + tasks, + DummyBuilder(), + DummyRunner(), + database, + measure_callbacks=[measure_callback.AddToDatabase()], + ) round_robin.tune() assert len(database) == num_trials_total * len(tasks) print(database.workload_reg) for task in tasks: - assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total + assert ( + len( + database.get_top_k( + database.commit_workload(task.mod), + 100000, + ) + ) + == num_trials_total + ) -def test_meta_schedule_task_scheduler_NIE(): +def test_meta_schedule_task_scheduler_not_implemented_error(): # pylint: disable=invalid-name class MyTaskScheduler(PyTaskScheduler): pass @@ -234,7 +255,7 @@ class MyTaskScheduler(PyTaskScheduler): MyTaskScheduler([], DummyBuilder(), DummyRunner(), DummyDatabase()) -def test_meta_schedule_task_scheduler_override_next_task_id_only(): +def test_meta_schedule_task_scheduler_override_next_task_id_only(): # pylint: disable=invalid-name class MyTaskScheduler(PyTaskScheduler): done = set() @@ -291,11 +312,27 @@ def next_task_id(self) -> int: ), ] database = DummyDatabase() - scheduler = MyTaskScheduler(tasks, DummyBuilder(), DummyRunner(), database) + scheduler = MyTaskScheduler( + tasks, + DummyBuilder(), + DummyRunner(), + database, + measure_callbacks=[ + measure_callback.AddToDatabase(), + ], + ) scheduler.tune() assert len(database) == num_trials_total * len(tasks) for task in tasks: - assert len(database.get_top_k(database.commit_workload(task.mod), 1e9)) == num_trials_total + assert ( + len( + database.get_top_k( + database.commit_workload(task.mod), + 100000, + ) + ) + == num_trials_total + ) if __name__ == "__main__": diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py new file mode 100644 index 000000000000..09aaa08d5185 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_tune_relay.py @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import logging +import tempfile +import pytest +import numpy as np +from typing import Tuple, List + +import tvm +from tvm import relay +from tvm.ir import IRModule +from tvm.runtime.ndarray import cpu, cuda +from tvm.target.target import Target +from tvm.contrib import graph_executor +from tvm.meta_schedule import ReplayTraceConfig +from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord +from tvm.meta_schedule.testing import MODEL_TYPE, MODEL_TYPES, get_torch_model +from tvm.meta_schedule.tune import tune_relay + +logging.basicConfig() +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) + + +class DummyDatabase(PyDatabase): + def __init__(self): + super().__init__() + self.records = [] + self.workload_reg = [] + + def has_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return True + return False + + def commit_tuning_record(self, record: TuningRecord) -> None: + self.records.append(record) + + def commit_workload(self, mod: IRModule) -> Workload: + for workload in self.workload_reg: + if tvm.ir.structural_equal(workload.mod, mod): + return workload + workload = Workload(mod) + self.workload_reg.append(workload) + return workload + + def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]: + return list( + filter( + lambda x: x.workload == workload, + sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)), + ) + )[: int(top_k)] + + def __len__(self) -> int: + return len(self.records) + + def print_results(self) -> None: + print("\n".join([str(r) for r in self.records])) + + +@pytest.mark.skip("Integration test") +@pytest.mark.parametrize("model_name", ["resnet18", "mobilenet_v2", "bert_base"]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("target", ["llvm --num-cores=16", "nvidia/geforce-rtx-3070"]) +def test_meta_schedule_tune_relay(model_name: str, batch_size: int, target: str): + if model_name == "inception_v3" and batch_size == 1: + pytest.skip("inception_v3 does not handle batch_size of 1") + + input_shape: Tuple[int, ...] + input_name = "input0" + dev = tvm.cpu() if str(target).startswith("llvm") else cuda() + if MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION: + seq_length = 128 + input_name = "input_ids" + input_shape = (batch_size, seq_length) + data = tvm.nd.array(np.random.randint(0, 30521, size=input_shape), dev) # embedding size + else: + if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION: + input_shape = (batch_size, 3, 299, 299) + elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: + input_shape = (batch_size, 3, 299, 299) + elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: + input_shape = (1, 3, 300, 300) + elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION: + input_shape = (batch_size, 3, 3, 299, 299) + else: + raise ValueError("Unsupported model: " + model_name) + data = tvm.nd.array(np.random.randn(*input_shape).astype("float32"), dev) + + output_shape: Tuple[int, int] = (batch_size, 1000) + + mod, params = get_torch_model( + model_name=model_name, + input_shape=input_shape, + output_shape=output_shape, + dtype="float32", + ) + + with tempfile.TemporaryDirectory() as work_dir: + target = Target(target) + database = DummyDatabase() + rt_mod: tvm.module = tune_relay( + mod=mod, + params=params, + target=target, + config=ReplayTraceConfig( + num_trials_per_iter=32, + num_trials_total=32, + ), + work_dir=work_dir, + database=database, + ) + # Compile without meta-scheduler for correctness check + with tvm.transform.PassContext(opt_level=0): + rt_mod2 = relay.build(mod, target=target, params=params) + + def get_output(data, lib): + module = graph_executor.GraphModule(lib["default"](dev)) + module.set_input(input_name, data) + module.run() + return module.get_output(0).numpy() + + # Check correctness + actual_output = get_output(data, rt_mod) + expected_output = get_output(data, rt_mod2) + assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4) + + +if __name__ == """__main__""": + # test_meta_schedule_tune_relay("resnet18", 1, "llvm --num-cores=16") + test_meta_schedule_tune_relay("resnet18", 1, "nvidia/geforce-rtx-3070") + # test_meta_schedule_tune_relay("mobilenet_v2", 1, "llvm --num-cores=16") + # test_meta_schedule_tune_relay("mobilenet_v2", 1, "nvidia/geforce-rtx-3070") + # test_meta_schedule_tune_relay("bert_base", 1, "llvm --num-cores=16") + # test_meta_schedule_tune_relay("bert_base", 1, "nvidia/geforce-rtx-3070") diff --git a/tests/python/unittest/test_meta_schedule_tune_te.py b/tests/python/unittest/test_meta_schedule_tune_te.py new file mode 100644 index 000000000000..a07bf1760346 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_tune_te.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import logging +import tempfile + +import pytest +from tvm.meta_schedule import ReplayTraceConfig, tune_te +from tvm.meta_schedule.testing import te_workload +from tvm.target.target import Target +from tvm.tir import Schedule + + +logging.basicConfig() +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) + + +@pytest.mark.skip("Integration test") +def test_tune_matmul(): + with tempfile.TemporaryDirectory() as work_dir: + sch: Schedule = tune_te( + tensors=te_workload.batch_matmul_nkkm(B=1, N=128, M=128, K=128), + target=Target("llvm --num-cores=16"), + config=ReplayTraceConfig( + num_trials_per_iter=32, + num_trials_total=32, + ), + work_dir=work_dir, + ) + if sch is None: + print("No valid schedule found!") + else: + print(sch.mod.script()) + print(sch.trace) + + +if __name__ == """__main__""": + test_tune_matmul() diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py new file mode 100644 index 000000000000..6e80e5a69c11 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import logging +import tempfile + +import tvm +import pytest +from tvm.meta_schedule import ReplayTraceConfig, tune_tir +from tvm.meta_schedule.tune_context import TuneContext +from tvm.meta_schedule import schedule_rule, postproc +from tvm.meta_schedule.space_generator import PostOrderApply +from tvm.script import tir as T +from tvm.target.target import Target +from tvm.te.operation import create_prim_func +from tvm.tir import Schedule +from tvm.meta_schedule.testing import te_workload, tir_tensor_intrin + +logging.basicConfig() +logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) + + +# pylint: disable=no-member,invalid-name,unused-variable + + +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +# pylint: enable=no-member,invalid-name,unused-variable + + +@pytest.mark.skip("Integration test") +def test_tune_matmul_cpu(): + with tempfile.TemporaryDirectory() as work_dir: + sch: Schedule = tune_tir( + mod=matmul, + target=Target("llvm --num-cores=16"), + config=ReplayTraceConfig( + num_trials_per_iter=32, + num_trials_total=32, + ), + work_dir=work_dir, + ) + if sch is None: + print("No valid schedule found!") + else: + print(sch.mod.script()) + print(sch.trace) + + +@pytest.mark.skip("Integration test") +def test_tune_matmul_cuda(): + with tempfile.TemporaryDirectory() as work_dir: + sch: Schedule = tune_tir( + mod=matmul, + target=Target("nvidia/geforce-rtx-3070"), + config=ReplayTraceConfig( + num_trials_per_iter=32, + num_trials_total=32, + ), + work_dir=work_dir, + ) + if sch is None: + print("No valid schedule found!") + else: + print(sch.mod.script()) + print(sch.trace) + + +@pytest.mark.skip("Integeration test") +def test_tune_matmul_cuda_tensor_core(): + n = 512 + mod = create_prim_func(te_workload.matmul_fp16(n, n, n)) + target = Target("nvidia/geforce-rtx-3070") + config = ReplayTraceConfig( + num_trials_per_iter=32, + num_trials_total=320, + ) + + class DefaultTensorCore: + @staticmethod + def _sch_rules(): + from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel + schedule_rule as M, + ) + + return [ + M.AutoInline( + into_producer=False, + into_consumer=True, + into_cache_only=False, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + disallow_op=None, + ), + M.MultiLevelTiling( + structure="SSSRRSRS", + tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"], + use_tensor_core=True, + max_innermost_factor=64, + vector_load_lens=[1, 2, 3, 4], + reuse_read=schedule_rule.ReuseType( + req="must", + levels=[4], + scope="shared", + ), + reuse_write=schedule_rule.ReuseType( + req="no", + levels=[], + scope="", + ), + ), + M.AutoInline( + into_producer=True, + into_consumer=True, + into_cache_only=True, + inline_const_tensor=True, + disallow_if_then_else=False, + require_injective=False, + require_ordered=False, + disallow_op=None, + ), + M.ParallelizeVectorizeUnroll( + max_jobs_per_core=-1, # disable parallelize + max_vectorize_extent=-1, # disable vectorize + unroll_max_steps=[0, 16, 64, 512, 1024], + unroll_explicit=True, + ), + ] + + @staticmethod + def _postproc(): + from tvm.meta_schedule import ( # pylint: disable=import-outside-toplevel + postproc as M, + ) + + return [ + M.RewriteCooperativeFetch(), + M.RewriteParallelVectorizeUnroll(), + M.RewriteReductionBlock(), + M.RewriteTensorCore(), + M.VerifyGPUCode(), + ] + + with tempfile.TemporaryDirectory() as work_dir: + sch: Schedule = tune_tir( + mod=mod, + target=target, + config=config, + work_dir=work_dir, + space=PostOrderApply(), + sch_rules=DefaultTensorCore._sch_rules, + postprocs=DefaultTensorCore._postproc, + num_threads=None, + ) + if sch is None: + print("No valid schedule found!") + else: + print(sch.mod.script()) + print(sch.trace) + + from tvm.contrib import nvcc + import numpy as np + + ctx = tvm.gpu(0) + if nvcc.have_tensorcore(ctx.compute_version): + with tvm.transform.PassContext(): + func = tvm.build(sch.mod["main"], [], "cuda") + print(sch.mod.script()) + print(func.imported_modules[0].get_source()) + a_np = np.random.uniform(size=(n, n)).astype("float16") + b_np = np.random.uniform(size=(n, n)).astype("float16") + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros((n, n), dtype="float32"), ctx) + evaluator = func.time_evaluator( + func.entry_name, ctx, number=3, repeat=1, min_repeat_ms=40 + ) + print("matmul with tensor core: %f ms" % (evaluator(a, b, c).mean * 1e3)) + + np.testing.assert_allclose( + c.asnumpy(), + np.matmul(a_np.astype("float32"), b_np.astype("float32")), + rtol=1e-4, + atol=1e-4, + ) + + +if __name__ == """__main__""": + test_tune_matmul_cpu() + test_tune_matmul_cuda() + test_tune_matmul_cuda_tensor_core() diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py new file mode 100644 index 000000000000..0bad0154a665 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from typing import List + +from tvm.tir import ( + Evaluate, + For, + ForKind, + IndexMap, + Var, + decl_buffer, + floordiv, + floormod, +) +from tvm.tir.analysis import expr_deep_equal +from tvm.tir.schedule.analysis import suggest_index_map + + +def _make_vars(*args: str) -> List[Var]: + return [Var(arg, dtype="int32") for arg in args] + + +def _make_loops(loop_vars: List[Var], extents: List[int]) -> List[For]: + assert len(loop_vars) == len(extents) + return [ + For( + loop_var=loop_var, + min_val=0, + extent=extent, + kind=ForKind.SERIAL, + body=Evaluate(0), + ) + for loop_var, extent in zip(loop_vars, extents) + ] + + +def _assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: + iters_1 = map1.apply(map2.src_iters) + iters_2 = map2.tgt_iters + assert len(iters_1) == len(iters_2) + for iter1, iter2 in zip(iters_1, iters_2): + assert expr_deep_equal(iter1, iter2) + + +def test_suggest_index_map_simple(): + i, j = _make_vars("i", "j") + index_map = suggest_index_map( + buffer=decl_buffer(shape=[8, 256]), + indices=[ + floordiv(i, 16) * 4 + floordiv(j, 16), + floormod(i, 16) * 16 + floormod(j, 16), + ], + loops=_make_loops( + loop_vars=[i, j], + extents=[32, 64], + ), + predicate=True, + ) + expected_index_map = IndexMap.from_func( + lambda x, y: [ + floordiv(x, 4), + floordiv(y, 16), + floormod(x, 4), + floormod(y, 16), + ], + ) + _assert_equal_index_map(index_map, expected_index_map) + + +def test_suggest_index_map_bijective(): + i, j = _make_vars("i", "j") + index_map = suggest_index_map( + buffer=decl_buffer(shape=[8]), + indices=[floormod(j, 4) * 2 + i], + loops=_make_loops( + loop_vars=[i, j], + extents=[2, 32], + ), + predicate=True, + ) + expected_index_map = IndexMap.from_func( + lambda x: [ + floormod(x, 2), + floordiv(x, 2), + ], + ) + _assert_equal_index_map(index_map, expected_index_map) + + +if __name__ == "__main__": + test_suggest_index_map_simple() + test_suggest_index_map_bijective() diff --git a/tests/python/unittest/test_tir_schedule_blockize.py b/tests/python/unittest/test_tir_schedule_blockize.py new file mode 100644 index 000000000000..57071cd7ad5b --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_blockize.py @@ -0,0 +1,227 @@ +import sys +import pytest +import tvm +from tvm import tir, te +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + + +@T.prim_func +def elementwise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + B = T.alloc_buffer((128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + +@T.prim_func +def blockize(a: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, (128, 128), "float32") + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + for i, j in T.grid(8, 8): + with T.block("blockized_B"): + vi, vj = T.axis.remap("SS", [i, j]) + for ii, jj in T.grid(16, 16): + with T.block("B"): + vii = T.axis.S(128, vi * 16 + ii) + vjj = T.axis.S(128, vj * 16 + jj) + B[vii, vjj] = A[vii, vjj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + T.float32(1) + + +@T.prim_func +def blockize_schedule_1(a: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with T.block("root"): + T.reads([]) + T.writes([]) + B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0_outer in range(0, 8): + for i1_outer in range(0, 8): + with T.block("blockized_B"): + vio = T.axis.S(8, i0_outer) + vjo = T.axis.S(8, i1_outer) + T.reads([A[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]]) + T.writes([B[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]]) + for i0_inner in range(0, 16): + for i1_inner in range(0, 16): + with T.block("B"): + vi = T.axis.S(128, ((vio * 16) + i0_inner)) + vj = T.axis.S(128, ((vjo * 16) + i1_inner)) + T.reads([A[vi : (vi + 1), vj : (vj + 1)]]) + T.writes([B[vi : (vi + 1), vj : (vj + 1)]]) + B[vi, vj] = A[vi, vj] * T.float32(2) + with T.block("blockized_C"): + vio = T.axis.S(8, i0_outer) + vjo = T.axis.S(8, i1_outer) + T.reads([B[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]]) + T.writes([C[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]]) + for ax0 in range(0, 16): + for ax1 in range(0, 16): + with T.block("C"): + vi = T.axis.S(128, ((vio * 16) + ax0)) + vj = T.axis.S(128, ((vjo * 16) + ax1)) + T.reads([B[vi : (vi + 1), vj : (vj + 1)]]) + T.writes([C[vi : (vi + 1), vj : (vj + 1)]]) + C[vi, vj] = B[vi, vj] + T.float32(1) + + +@T.prim_func +def blockize_schedule_2(a: T.handle, c: T.handle) -> None: + C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with T.block("root"): + T.reads([]) + T.writes([]) + B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0_outer in range(0, 4): + for i1_outer in range(0, 4): + for ax0 in range(0, 2): + for ax1 in range(0, 2): + with T.block("blockized_B"): + vio = T.axis.S(8, ((i0_outer * 2) + ax0)) + vjo = T.axis.S(8, ((i1_outer * 2) + ax1)) + T.reads( + [A[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]] + ) + T.writes( + [B[(vio * 16) : ((vio * 16) + 16), (vjo * 16) : ((vjo * 16) + 16)]] + ) + for i0_inner in range(0, 16): + for i1_inner in range(0, 16): + with T.block("B"): + vi = T.axis.S(128, ((vio * 16) + i0_inner)) + vj = T.axis.S(128, ((vjo * 16) + i1_inner)) + T.reads([A[vi : (vi + 1), vj : (vj + 1)]]) + T.writes([B[vi : (vi + 1), vj : (vj + 1)]]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i0_inner_1 in range(0, 32): + for i1_inner_1 in range(0, 32): + with T.block("C"): + vi = T.axis.S(128, ((i0_outer * 32) + i0_inner_1)) + vj = T.axis.S(128, ((i1_outer * 32) + i1_inner_1)) + T.reads([B[vi : (vi + 1), vj : (vj + 1)]]) + T.writes([C[vi : (vi + 1), vj : (vj + 1)]]) + C[vi, vj] = B[vi, vj] + T.float32(1) + + +@T.prim_func +def rowsum(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer( + b, + [ + 128, + ], + ) + for k, i in T.grid(128, 128): + with T.block("B"): + vk, vi = T.axis.remap("RS", [k, i]) + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@T.prim_func +def rowsum_blockized(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128]) + with T.block("blockized_B"): + vko = T.axis.R(1, 0) + vio = T.axis.S(1, 0) + with T.init(): + for i1 in T.serial(0, 128): + with T.block("B_init"): + vi_init = T.axis.S(128, i1) + B[vi_init] = T.float32(0) + for i0, i1_1 in T.grid(128, 128): + with T.block("B"): + vk, vi = T.axis.remap("RS", [i0, i1_1]) + B[vi] = B[vi] + A[vi, vk] + + +def test_blockize(): + func = elementwise + # schedule + s = tir.Schedule(func, debug_mask="all") + B = s.get_block("B") + _ = s.get_block("C") + x, y = s.get_loops(B) + xo, xi = s.split(x, factors=[None, 16]) + yo, yi = s.split(y, factors=[None, 16]) + s.reorder(xo, yo, xi, yi) + s.blockize(xi) + tvm.ir.assert_structural_equal(blockize, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_blockize_schedule(): + func = elementwise + # test 1 + s = tir.Schedule(func, debug_mask="all") + B = s.get_block("B") + C = s.get_block("C") + x, y = s.get_loops(B) + xo, xi = s.split(x, factors=[None, 16]) + yo, yi = s.split(y, factors=[None, 16]) + s.reorder(xo, yo, xi, yi) + s.blockize(xi) + s.reverse_compute_at(C, yo) + s.blockize(s.get_loops(C)[-2]) + tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_1) + verify_trace_roundtrip(sch=s, mod=func) + # test 2 + s = tir.Schedule(func, debug_mask="all") + B = s.get_block("B") + C = s.get_block("C") + x, y = s.get_loops(C) + xo, xi = s.split(x, factors=[None, 16]) + yo, yi = s.split(y, factors=[None, 16]) + s.reorder(xo, yo, xi, yi) + s.blockize(xi) + s.compute_at(B, yo) + s.blockize(s.get_loops(B)[-2]) + tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_1) + verify_trace_roundtrip(sch=s, mod=func) + # test 3 + s = tir.Schedule(func, debug_mask="all") + B = s.get_block("B") + C = s.get_block("C") + x, y = s.get_loops(B) + xo, xi = s.split(x, factors=[None, 16]) + yo, yi = s.split(y, factors=[None, 16]) + s.reorder(xo, yo, xi, yi) + b_outer = s.blockize(xi) + xC, yC = s.get_loops(C) + xCo, xCi = s.split(xC, factors=[None, 32]) + yCo, yCi = s.split(yC, factors=[None, 32]) + s.reorder(xCo, yCo, xCi, yCi) + s.compute_at(b_outer, yCo) + tvm.ir.assert_structural_equal(s.mod["main"], blockize_schedule_2) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_blockize_init_loops(): + s = tir.Schedule(rowsum, debug_mask="all") + k, _ = s.get_loops(s.get_block("B")) + s.blockize(k) + tvm.ir.assert_structural_equal(s.mod["main"], rowsum_blockized) + verify_trace_roundtrip(sch=s, mod=rowsum) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index 240a1cc9f53b..6bb5cdd0b461 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -742,21 +742,131 @@ def read_out_of_bound(a: T.handle, c:T.handle) -> None: @T.prim_func -def read_out_of_bound_after_compute_at(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [16], "float32") - B = T.alloc_buffer([16], "float32") - C = T.match_buffer(c, [16], "float32") - for j in T.serial(0, 16): - for i in T.serial(0, T.min(1, 15 - j) + 1): +def read_out_of_bound_after_compute_at(A: T.Buffer[(16,), "float32"], C: T.Buffer[(16,), "float32"]) -> None: + # body + # with T.block("root") + B = T.alloc_buffer([16], dtype="float32") + for j in T.serial(16): + for ax0 in T.serial(2): with T.block("B"): - v = T.axis.S(16, j + i) + v = T.axis.spatial(16, j + ax0) + T.reads(A[v]) + T.writes(B[v]) + T.block_attr({"require_bound_predicate":v >= 0 and v < 16}) B[v] = A[v] with T.block("C"): - v = T.axis.S(16, j) - T.reads([B[v : v + 2]]) + v = T.axis.spatial(16, j) + T.reads(B[v : v + 2]) + T.writes(C[v]) C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32") +@T.prim_func +def tiled_pooling_cache(a: T.handle, b: T.handle) -> None: + X = T.match_buffer(a, [224, 224], dtype="float32") + Y = T.match_buffer(b, [224, 224], dtype="float32") + cache = T.alloc_buffer([224, 224], dtype="float32") + dache = T.alloc_buffer([224, 224], dtype="float32") + for hh, ww in T.grid(224, 224): + with T.block("cache"): + h, w = T.axis.remap("SS", [hh, ww]) + T.reads([X[h, w]]) + T.writes([cache[h, w]]) + cache[h, w] = X[h, w] + for hh, ww in T.grid(224, 224): + with T.block("dache"): + h, w = T.axis.remap("SS", [hh, ww]) + T.reads([X[h, w]]) + T.writes([dache[h, w]]) + dache[h, w] = X[h, w] + for hh_0, ww_0, hh_1, ww_1, khh, kww in T.grid(28, 28, 8, 8, 3, 3): + with T.block("compute"): + h = T.axis.spatial(224, hh_0 * 8 + hh_1) + w = T.axis.spatial(224, ww_0 * 8 + ww_1) + kh, kw = T.axis.remap("RR", [khh, kww]) + T.reads([Y[h, w], cache[h + kh - 1, w + kw - 1], dache[h + kh - 1, w + kw - 1]]) + T.writes([Y[h, w]]) + with T.init(): + Y[h, w] = 0.0 + Y[h, w] = T.max(Y[h, w], T.if_then_else( + T.likely(1 <= h + kh, dtype="bool") and \ + T.likely(h + kh < 225, dtype="bool") and \ + T.likely(1 <= w + kw, dtype="bool") and \ + T.likely(w + kw < 225, dtype="bool"), + cache[h + kh - 1, w + kw - 1] + dache[h + kh - 1, w + kw - 1], 0.0, dtype="float32")) + + +@T.prim_func +def tiled_pooling_cache_after_compute_at(a: T.handle, b: T.handle) -> None: + X = T.match_buffer(a, [224, 224], dtype="float32") + Y = T.match_buffer(b, [224, 224], dtype="float32") + cache = T.alloc_buffer([224, 224], dtype="float32") + dache = T.alloc_buffer([224, 224], dtype="float32") + for hh_0, ww_0 in T.grid(28, 28): + for ax0, ax1 in T.grid(10, 10): + with T.block("cache"): + h = T.axis.spatial(224, hh_0 * 8 - 1 + ax0) + w = T.axis.spatial(224, ww_0 * 8 - 1 + ax1) + T.reads(X[h, w]) + T.writes(cache[h, w]) + T.block_attr({"require_bound_predicate":h >= 0 and h < 224 and w >= 0 and w < 224}) + cache[h, w] = X[h, w] + for ax0, ax1 in T.grid(10, 10): + with T.block("dache"): + h = T.axis.spatial(224, hh_0 * 8 - 1 + ax0) + w = T.axis.spatial(224, ww_0 * 8 - 1 + ax1) + T.reads(X[h, w]) + T.writes(dache[h, w]) + T.block_attr({"require_bound_predicate":h >= 0 and h < 224 and w >= 0 and w < 224}) + dache[h, w] = X[h, w] + for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3): + with T.block("compute"): + h = T.axis.spatial(224, hh_0 * 8 + hh_1) + w = T.axis.spatial(224, ww_0 * 8 + ww_1) + kh, kw = T.axis.remap("RR", [khh, kww]) + T.reads([Y[h, w], cache[h + kh - 1, w + kw - 1], dache[h + kh - 1, w + kw - 1]]) + T.writes([Y[h, w]]) + with T.init(): + Y[h, w] = 0.0 + Y[h, w] = T.max(Y[h, w], T.if_then_else( + T.likely(1 <= h + kh, dtype="bool") and \ + T.likely(h + kh < 225, dtype="bool") and \ + T.likely(1 <= w + kw, dtype="bool") and \ + T.likely(w + kw < 225, dtype="bool"), + cache[h + kh - 1, w + kw - 1]+ dache[h + kh - 1, w + kw - 1], 0.0, dtype="float32")) + + +@T.prim_func +def floordiv_and_floormod_indices(a: T.handle, b: T.handle) -> None: + X = T.match_buffer(a, [16, 16]) + Y = T.match_buffer(b, [256]) + temp = T.alloc_buffer([16, 16]) + for i, j in T.grid(16, 16): + with T.block("A"): + v_i, v_j = T.axis.remap("SS", [i, j]) + temp[v_i, v_j] = X[v_j, v_i] + 1.0 + for i in T.serial(0, 256): + with T.block("B"): + v_i = T.axis.remap("S", [i]) + Y[v_i] = temp[v_i // 16, v_i % 16] + + +@T.prim_func +def floordiv_and_floormod_indices_after_reverse_compute_at(a: T.handle, b: T.handle) -> None: + X = T.match_buffer(a, [16, 16], dtype="float32") + Y = T.match_buffer(b, [256], dtype="float32") + temp = T.alloc_buffer([16, 16], dtype="float32") + for i in T.serial(0, 16): + for j in T.serial(0, 16): + with T.block("A"): + v_i, v_j = T.axis.remap("SS", [i, j]) + temp[v_i, v_j] = X[v_j, v_i] + T.float32(1) + for ax0 in T.serial(0, 16): + with T.block("B"): + v_i = T.axis.spatial(256, i * 16 + ax0) + Y[v_i] = temp[v_i // 16, v_i % 16] + + # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks # fmt: on @@ -926,5 +1036,28 @@ def test_fail_all_producers_under_loop(): sch.reverse_compute_at(block, loop) +def test_compute_at_tiled_pooling_cache(): + sch = tir.Schedule(tiled_pooling_cache, debug_mask="all") + compute = sch.get_block("compute") + _, w_o, _, _, _, _ = sch.get_loops(compute) + cache = sch.get_block("cache") + dache = sch.get_block("dache") + sch.compute_at(cache, w_o) + sch.compute_at(dache, w_o) + tvm.ir.assert_structural_equal(tiled_pooling_cache_after_compute_at, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=tiled_pooling_cache) + + +def test_reverse_compute_at_floordiv_and_floormod_indices(): + sch = tir.Schedule(floordiv_and_floormod_indices, debug_mask="all") + A = sch.get_block("A") + B = sch.get_block("B") + sch.reverse_compute_at(B, sch.get_loops(A)[0]) + tvm.ir.assert_structural_equal( + floordiv_and_floormod_indices_after_reverse_compute_at, sch.mod["main"] + ) + verify_trace_roundtrip(sch=sch, mod=floordiv_and_floormod_indices) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_read_write_at.py b/tests/python/unittest/test_tir_schedule_read_write_at.py new file mode 100644 index 000000000000..79a7aad10f25 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_read_write_at.py @@ -0,0 +1,221 @@ +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest + +import tvm +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable + +@T.prim_func +def cuda_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=undefined-loop-variable + A = T.match_buffer(a, [2048, 2048], "float32") + B = T.match_buffer(b, [2048, 2048], "float32") + C = T.match_buffer(c, [2048, 2048], "float32") + for by in T.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in T.thread_binding(0, 2, thread = "vthread.y"): + for vx in T.thread_binding(0, 2, thread = "vthread.x"): + for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): + for k0 in T.serial(0, 256): + for k1 in T.unroll(0, 8): + for _, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C[vi, vj], A[vi, vk], B[vk, vj]]) + T.writes([C[vi, vj]]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@T.prim_func +def cuda_matmul_read_at_a(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048], dtype="float32") + C = T.match_buffer(c, [2048, 2048], dtype="float32") + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + for by in T.thread_binding(0, 32, thread="blockIdx.y"): + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): + for vy in T.thread_binding(0, 2, thread="vthread.y"): + for vx in T.thread_binding(0, 2, thread="vthread.x"): + for ty in T.thread_binding(0, 8, thread="threadIdx.y"): + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): + for k0 in T.serial(0, 256): + with T.block("A_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(256, k0) + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 8): + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] + for k1 in T.unroll(0, 8): + for v_, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C[vi, vj], A_shared[vi, vk], B[vk, vj]]) + T.writes([C[vi, vj]]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B[vk, vj] + + +@T.prim_func +def cuda_matmul_read_at_ab(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048], dtype="float32") + C = T.match_buffer(c, [2048, 2048], dtype="float32") + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + for by in T.thread_binding(0, 32, thread="blockIdx.y"): + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): + for vy in T.thread_binding(0, 2, thread="vthread.y"): + for vx in T.thread_binding(0, 2, thread="vthread.x"): + for ty in T.thread_binding(0, 8, thread="threadIdx.y"): + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): + for k0 in T.serial(0, 256): + with T.block("A_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(256, k0) + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 8): + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] + with T.block("B_shared"): + v0 = T.axis.S(256, k0) + v1 = T.axis.S(32, bx) + T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(8, 64): + B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1] + for k1 in T.unroll(0, 8): + for v_, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C[vi, vj], A_shared[vi, vk], B_shared[vk, vj]]) + T.writes([C[vi, vj]]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj] + +@T.prim_func +def cuda_matmul_write_at_c(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048], dtype="float32") + C = T.match_buffer(c, [2048, 2048], dtype="float32") + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + C_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + for by in T.thread_binding(0, 32, thread="blockIdx.y"): + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): + for vy in T.thread_binding(0, 2, thread="vthread.y"): + for vx in T.thread_binding(0, 2, thread="vthread.x"): + for ty in T.thread_binding(0, 8, thread="threadIdx.y"): + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): + for k0 in T.serial(0, 256): + with T.block("A_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(256, k0) + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 8): + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] + with T.block("B_shared"): + v0 = T.axis.S(256, k0) + v1 = T.axis.S(32, bx) + T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(8, 64): + B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1] + for k1 in T.unroll(0, 8): + for v_, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C_shared[vi, vj], A_shared[vi, vk], B_shared[vk, vj]]) + T.writes([C_shared[vi, vj]]) + with T.init(): + C_shared[vi, vj] = T.float32(0) + C_shared[vi, vj] = C_shared[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj] + with T.block("C_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(32, bx) + T.reads([C_shared[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]]) + T.writes([C[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 64): + C[v0 * 64 + ax0, v1 * 64 + ax1] = C_shared[v0 * 64 + ax0, v1 * 64 + ax1] + + +# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable +# fmt: on + + +def test_read_at_global_to_shared_a(): + sch = tir.Schedule(cuda_matmul, debug_mask="all") + block = sch.get_block("C") + # pylint: disable=invalid-name + _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block) + # pylint: enable=invalid-name + sch.read_at(k0, block, 1, "shared") + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_a) + verify_trace_roundtrip(sch, cuda_matmul) + + +def test_read_at_global_to_shared_ab(): + sch = tir.Schedule(cuda_matmul_read_at_a, debug_mask="all") + block = sch.get_block("C") + # pylint: disable=invalid-name + _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block) + # pylint: enable=invalid-name + sch.read_at(k0, block, 2, "shared") + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_ab) + verify_trace_roundtrip(sch, cuda_matmul_read_at_a) + + +def test_read_at_local_to_shared_c(): + sch = tir.Schedule(cuda_matmul_read_at_ab, debug_mask="all") + block = sch.get_block("C") + # pylint: disable=invalid-name + _by, _bx, _vy, _vx, _ty, tx, _k0, _k1, _, _i, _j = sch.get_loops(block) + # pylint: enable=invalid-name + sch.write_at(tx, block, 0, "shared") + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_write_at_c) + verify_trace_roundtrip(sch, cuda_matmul_read_at_ab) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py index 5f5daa144e96..5ad366b2fa02 100644 --- a/tests/python/unittest/test_tir_schedule_reduction.py +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -185,6 +185,34 @@ def matmul_decompose_with_annotation(a: T.handle, b: T.handle, c: T.handle) -> N C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] +@T.prim_func +def colsum_with_vectorization(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 32], dtype="float32") + B = T.match_buffer(b, [32], dtype="float32") + for k in T.serial(0, 128): + for i in T.vectorized(0, 32): + with T.block("B"): + vk, vi = T.axis.remap("RS", [k, i]) + with T.init(): + B[vi] = T.float32(0) + B[vi] = B[vi] + A[vk, vi] + + +@T.prim_func +def colsum_decompose_with_vectorization(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [128, 32], dtype="float32") + B = T.match_buffer(b, [32], dtype="float32") + for i in T.vectorized(0, 32): + with T.block("B_init"): + vi = T.axis.S(32, i) + B[vi] = T.float32(0) + for k in T.serial(0, 128): + for i in T.vectorized(0, 32): + with T.block("B"): + vk, vi = T.axis.remap("RS", [k, i]) + B[vi] = B[vi] + A[vk, vi] + + # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -243,5 +271,16 @@ def test_reduction_decompose_with_annotation(): verify_trace_roundtrip(s, mod=matmul_with_annotation) +def test_reduction_decompose_with_different_for_kind(): + s = tir.Schedule(colsum_with_vectorization, debug_mask="all") + B = s.get_block("B") + k, _ = s.get_loops(B) + B_init = s.decompose_reduction(B, k) + tvm.ir.assert_structural_equal(s.mod["main"], colsum_decompose_with_vectorization) + assert s.get(B).same_as(s.get(s.get_block("B_update"))) + assert s.get(B_init).same_as(s.get(s.get_block("B_init"))) + verify_trace_roundtrip(s, mod=colsum_with_vectorization) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py index f5fc5a73d038..d9ddec6795a9 100644 --- a/tests/python/unittest/test_tir_schedule_rfactor.py +++ b/tests/python/unittest/test_tir_schedule_rfactor.py @@ -37,40 +37,36 @@ def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: with T.block("update"): vi, vj = T.axis.remap("SS", [i0, i1]) vk = T.axis.R(128, i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner) - T.reads([C[vi, vj], A[vi, vk], B[vj, vk]]) - T.writes([C[vi, vj]]) with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) @T.prim_func -def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [128, 128]) - B = T.match_buffer(b, [128, 128]) - C = T.match_buffer(c, [128, 128]) - C_rf = T.alloc_buffer([4, 128, 128]) - +def matmul_rfactor( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], +) -> None: + C_rf = T.alloc_buffer([4, 128, 128], dtype="float32") for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): with T.block("update_rf"): - vi2_inner_inner = T.axis.S(4, i2_inner_inner) - vi = T.axis.S(128, i0) - vj = T.axis.S(128, i1) - vi2_outer = T.axis.R(4, i2_outer) - vi2_inner_outer = T.axis.R(8, i2_inner_outer) - with T.init(): - C_rf[vi2_inner_inner, vi, vj] = 0.0 - C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + ( - A[vi, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)] - * B[vj, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)] + vi, vj, vi2_outer, vi2_inner_outer, vi2_inner_inner = T.axis.remap( + "SSRRS", [i0, i1, i2_outer, i2_inner_outer, i2_inner_inner] ) - - for i0_1, i1_1, i2_inner_inner_1 in T.grid(128, 128, 4): + with T.init(): + C_rf[vi2_inner_inner, vi, vj] = T.float32(0) + C_rf[vi2_inner_inner, vi, vj] = ( + C_rf[vi2_inner_inner, vi, vj] + + A[vi, vi2_outer * 32 + vi2_inner_outer * 4 + vi2_inner_inner] + * B[vj, vi2_outer * 32 + vi2_inner_outer * 4 + vi2_inner_inner] + ) + for i0, i1, i2_inner_inner in T.grid(128, 128, 4): with T.block("update"): - vi2_inner_inner_1, vi_1, vj_1 = T.axis.remap("RSS", [i2_inner_inner_1, i0_1, i1_1]) + vi, vj, vi2_inner_inner = T.axis.remap("SSR", [i0, i1, i2_inner_inner]) with T.init(): - C[vi_1, vj_1] = 0.0 - C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1] + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + C_rf[vi2_inner_inner, vi, vj] @T.prim_func @@ -141,24 +137,22 @@ def square_sum(a: T.handle, c: T.handle) -> None: @T.prim_func -def square_sum_rfactor(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [16, 256, 256]) - C = T.match_buffer(c, [16]) - C_rf = T.alloc_buffer([16, 256]) - - for i0, i1, i2 in T.grid(16, 256, 256): +def square_sum_rfactor( + A: T.Buffer[(16, 256, 256), "float32"], C: T.Buffer[(16,), "float32"] +) -> None: + C_rf = T.alloc_buffer([16, 256], dtype="float32") + for b0, i0, j0 in T.grid(16, 256, 256): with T.block("C_rf"): - vi2, b, i = T.axis.remap("SSR", [i2, i0, i1]) + b, i, vj0 = T.axis.remap("SRS", [b0, i0, j0]) with T.init(): - C_rf[b, vi2] = 0.0 - C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2]) - - for i0_1, i2_1 in T.grid(16, 256): + C_rf[b, vj0] = T.float32(0) + C_rf[b, vj0] = C_rf[b, vj0] + A[b, i, vj0] * A[b, i, vj0] + for b0, j0 in T.grid(16, 256): with T.block("C"): - vi2_1, b_1 = T.axis.remap("RS", [i2_1, i0_1]) + b, vj0 = T.axis.remap("SR", [b0, j0]) with T.init(): - C[b_1] = 0.0 - C[b_1] = C[b_1] + C_rf[b_1, vi2_1] + C[b] = T.float32(0) + C[b] = C[b] + C_rf[b, vj0] @T.prim_func @@ -167,51 +161,150 @@ def transformed_square_sum_square_root(a: T.handle, d: T.handle) -> None: D = T.match_buffer(d, [16]) C = T.alloc_buffer([16]) - for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): + for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 32768, 2): with T.block("C"): b = T.axis.S(16, i0) - i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256)) - j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256)) - T.reads([C[b], A[b, i, j]]) - T.writes([C[b]]) + i = T.axis.R(256, T.floordiv(i1_i2_fused_outer * 2 + i1_i2_fused_inner, 256)) + j = T.axis.R(256, T.floormod(i1_i2_fused_outer * 2 + i1_i2_fused_inner, 256)) with T.init(): C[b] = 0.0 C[b] = C[b] + (A[b, i, j] * A[b, i, j]) for i0_1 in T.serial(0, 16): with T.block("D"): b_1 = T.axis.S(16, i0_1) - T.reads([C[b_1]]) - T.writes([D[b_1]]) D[b_1] = T.sqrt(C[b_1], dtype="float32") @T.prim_func -def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None: +def square_sum_square_root_rfactor( + A: T.Buffer[(16, 256, 256), "float32"], D: T.Buffer[(16,), "float32"] +) -> None: + C = T.alloc_buffer([16], dtype="float32") + C_rf = T.alloc_buffer([2, 16], dtype="float32") + for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 32768, 2): + with T.block("C_rf"): + b, vi1_i2_fused_outer, vi1_i2_fused_inner = T.axis.remap( + "SRS", [i0, i1_i2_fused_outer, i1_i2_fused_inner] + ) + with T.init(): + C_rf[vi1_i2_fused_inner, b] = T.float32(0) + C_rf[vi1_i2_fused_inner, b] = ( + C_rf[vi1_i2_fused_inner, b] + + A[ + b, + (vi1_i2_fused_outer * 2 + vi1_i2_fused_inner) // 256, + (vi1_i2_fused_outer * 2 + vi1_i2_fused_inner) % 256, + ] + * A[ + b, + (vi1_i2_fused_outer * 2 + vi1_i2_fused_inner) // 256, + (vi1_i2_fused_outer * 2 + vi1_i2_fused_inner) % 256, + ] + ) + for i0, i1_i2_fused_inner in T.grid(16, 2): + with T.block("C"): + b, vi1_i2_fused_inner = T.axis.remap("SR", [i0, i1_i2_fused_inner]) + with T.init(): + C[b] = T.float32(0) + C[b] = C[b] + C_rf[vi1_i2_fused_inner, b] + for i0_1 in T.serial(16): + with T.block("D"): + b_1 = T.axis.spatial(16, i0_1) + D[b_1] = T.sqrt(C[b_1], dtype="float32") + + +@T.prim_func +def transformed_square_sum_square_root_factor_one_1(a: T.handle, d: T.handle) -> None: A = T.match_buffer(a, [16, 256, 256]) D = T.match_buffer(d, [16]) C = T.alloc_buffer([16]) - C_rf = T.alloc_buffer([1, 16]) for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): - with T.block("C_rf"): - vi1_i2_fused_inner, b = T.axis.remap("SS", [i1_i2_fused_inner, i0]) + with T.block("C"): + b = T.axis.S(16, i0) i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256)) j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256)) with T.init(): - C_rf[vi1_i2_fused_inner, b] = 0.0 - C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j]) + C[b] = 0.0 + C[b] = C[b] + (A[b, i, j] * A[b, i, j]) + for i0_1 in T.serial(0, 16): + with T.block("D"): + b_1 = T.axis.S(16, i0_1) + D[b_1] = T.sqrt(C[b_1], dtype="float32") + - for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1): +@T.prim_func +def square_sum_square_root_factor_one_1_rfactor( + A: T.Buffer[(16, 256, 256), "float32"], D: T.Buffer[(16,), "float32"] +) -> None: + C = T.alloc_buffer([16], dtype="float32") + C_rf = T.alloc_buffer([1, 16], dtype="float32") + for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1): + with T.block("C_rf"): + b = T.axis.spatial(16, i0) + i = T.axis.reduce(256, i1_i2_fused_outer // 256) + j = T.axis.reduce(256, i1_i2_fused_outer % 256) + vi1_i2_fused_inner = T.axis.spatial(1, i1_i2_fused_inner) + with T.init(): + C_rf[vi1_i2_fused_inner, b] = T.float32(0) + C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + A[b, i, j] * A[b, i, j] + for i0, i1_i2_fused_inner in T.grid(16, 1): with T.block("C"): - vi1_i2_fused_inner_1, b_1 = T.axis.remap("RS", [i1_i2_fused_inner_1, i0_1]) + b, vi1_i2_fused_inner = T.axis.remap("SR", [i0, i1_i2_fused_inner]) with T.init(): - C[b_1] = 0.0 - C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1] + C[b] = T.float32(0) + C[b] = C[b] + C_rf[vi1_i2_fused_inner, b] + for i0_1 in T.serial(16): + with T.block("D"): + b_1 = T.axis.spatial(16, i0_1) + D[b_1] = T.sqrt(C[b_1], dtype="float32") + - for i0_2 in T.serial(0, 16): +@T.prim_func +def transformed_square_sum_square_root_factor_one_2(a: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [16, 256, 256]) + D = T.match_buffer(d, [16]) + C = T.alloc_buffer([16]) + + for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 1, 65536): + with T.block("C"): + b = T.axis.S(16, i0) + i = T.axis.R(256, T.floordiv(i1_i2_fused_inner, 256)) + j = T.axis.R(256, T.floormod(i1_i2_fused_inner, 256)) + with T.init(): + C[b] = 0.0 + C[b] = C[b] + (A[b, i, j] * A[b, i, j]) + for i0_1 in T.serial(0, 16): with T.block("D"): - b_2 = T.axis.S(16, i0_2) - D[b_2] = T.sqrt(C[b_2], dtype="float32") + b_1 = T.axis.S(16, i0_1) + D[b_1] = T.sqrt(C[b_1], dtype="float32") + + +@T.prim_func +def square_sum_square_root_factor_one_2_rfactor( + A: T.Buffer[(16, 256, 256), "float32"], D: T.Buffer[(16,), "float32"] +) -> None: + C = T.alloc_buffer([16], dtype="float32") + C_rf = T.alloc_buffer([16, 1], dtype="float32") + for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 1, 65536): + with T.block("C_rf"): + b = T.axis.spatial(16, i0) + i = T.axis.reduce(256, i1_i2_fused_inner // 256) + j = T.axis.reduce(256, i1_i2_fused_inner % 256) + vi1_i2_fused_outer = T.axis.spatial(1, i1_i2_fused_outer) + with T.init(): + C_rf[b, vi1_i2_fused_outer] = T.float32(0) + C_rf[b, vi1_i2_fused_outer] = C_rf[b, vi1_i2_fused_outer] + A[b, i, j] * A[b, i, j] + for i0, i1_i2_fused_outer in T.grid(16, 1): + with T.block("C"): + b, vi1_i2_fused_outer = T.axis.remap("SR", [i0, i1_i2_fused_outer]) + with T.init(): + C[b] = T.float32(0) + C[b] = C[b] + C_rf[b, vi1_i2_fused_outer] + for i0_1 in T.serial(16): + with T.block("D"): + b_1 = T.axis.spatial(16, i0_1) + D[b_1] = T.sqrt(C[b_1], dtype="float32") @T.prim_func @@ -229,26 +322,24 @@ def square_sum_with_annotation(a: T.handle, c: T.handle) -> None: @T.prim_func -def square_sum_with_annotation_rfactor(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [16, 256, 256]) - C = T.match_buffer(c, [16]) - C_rf = T.alloc_buffer([16, 256]) - - for i0, i1, i2 in T.grid(16, 256, 256): +def square_sum_with_annotation_rfactor( + A: T.Buffer[(16, 256, 256), "float32"], C: T.Buffer[(16,), "float32"] +) -> None: + C_rf = T.alloc_buffer([16, 256], dtype="float32") + for b0, i0, j0 in T.grid(16, 256, 256): with T.block("C_rf"): + b, i, vj0 = T.axis.remap("SRS", [b0, i0, j0]) T.block_attr({"test_annotation": 1}) - vi2, b, i = T.axis.remap("SSR", [i2, i0, i1]) with T.init(): - C_rf[b, vi2] = 0.0 - C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2]) - - for i0_1, i2_1 in T.grid(16, 256): + C_rf[b, vj0] = T.float32(0) + C_rf[b, vj0] = C_rf[b, vj0] + A[b, i, vj0] * A[b, i, vj0] + for b0, j0 in T.grid(16, 256): with T.block("C"): + b, vj0 = T.axis.remap("SR", [b0, j0]) T.block_attr({"test_annotation": 1}) - vi2_1, b_1 = T.axis.remap("RS", [i2_1, i0_1]) with T.init(): - C[b_1] = 0.0 - C[b_1] = C[b_1] + C_rf[b_1, vi2_1] + C[b] = T.float32(0) + C[b] = C[b] + C_rf[b, vj0] @T.prim_func @@ -370,24 +461,20 @@ def rowsum_zero_dim(a: T.handle, b: T.handle) -> None: @T.prim_func -def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, [128]) - B = T.match_buffer(b, []) - B_rf = T.alloc_buffer([128]) - - for i in range(128): +def rowsum_zero_dim_rfactor(A: T.Buffer[(128,), "float32"], B: T.Buffer[(), "float32"]) -> None: + B_rf = T.alloc_buffer([128], dtype="float32") + for k0 in T.serial(128): with T.block("B_rf"): - vi0 = T.axis.S(128, i) + vk0 = T.axis.spatial(128, k0) with T.init(): - B_rf[vi0] = 0.0 - B_rf[vi0] = B_rf[vi0] + A[vi0] - - for i in range(128): + B_rf[vk0] = T.float32(0) + B_rf[vk0] = B_rf[vk0] + A[vk0] + for k0 in T.serial(128): with T.block("B"): - vi0_1 = T.axis.R(128, i) + vk0 = T.axis.reduce(128, k0) with T.init(): - B[()] = 0.0 - B[()] = B[()] + B_rf[vi0_1] + B[()] = T.float32(0) + B[()] = B[()] + B_rf[vk0] @T.prim_func @@ -405,20 +492,20 @@ def rowsum_predicate(a: T.handle, b: T.handle) -> None: @T.prim_func -def rowsum_predicate_rfactor(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, [128, 128], dtype="float32") - B = T.match_buffer(b, [128], dtype="float32") +def rowsum_predicate_rfactor( + A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"] +) -> None: B_rf = T.alloc_buffer([128, 13], dtype="float32") for i, k_0, k_1 in T.grid(128, 13, 10): with T.block("B_rf"): - vk_0, vi, vk_1 = T.axis.remap("SSR", [k_0, i, k_1]) + vi, vk_0, vk_1 = T.axis.remap("SSR", [i, k_0, k_1]) T.where(k_0 * 10 + k_1 < 128) with T.init(): B_rf[vi, vk_0] = T.float32(0) B_rf[vi, vk_0] = B_rf[vi, vk_0] + A[vi, vk_0 * 10 + vk_1] for i, k_0 in T.grid(128, 13): with T.block("B"): - vk_0, vi = T.axis.remap("RS", [k_0, i]) + vi, vk_0 = T.axis.remap("SR", [i, k_0]) with T.init(): B[vi] = T.float32(0) B[vi] = B[vi] + B_rf[vi, vk_0] @@ -466,50 +553,49 @@ def multiple_reduction_blocks(a: T.handle, f: T.handle) -> None: @T.prim_func -def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None: - A = T.match_buffer(a, [16, 16, 16]) - C = T.alloc_buffer([16, 16]) - D = T.alloc_buffer([16, 16]) - E = T.alloc_buffer([16, 16]) - F = T.match_buffer(f, [16, 16]) - C_rf = T.alloc_buffer([16, 16, 4]) - +def multiple_reduction_blocks_rfactor( + A: T.Buffer[(16, 16, 16), "float32"], F: T.Buffer[(16, 16), "float32"] +) -> None: + C = T.alloc_buffer([16, 16], dtype="float32") + D = T.alloc_buffer([16, 16], dtype="float32") + E = T.alloc_buffer([16, 16], dtype="float32") + C_rf = T.alloc_buffer([16, 16, 4], dtype="float32") for i, j1, k1o, k1i in T.grid(16, 16, 4, 4): with T.block("C_rf"): - vk1o, ci, cj, vk1i = T.axis.remap("SSSR", [k1o, i, j1, k1i]) + ci, cj, vk1o, vk1i = T.axis.remap("SSSR", [i, j1, k1o, k1i]) with T.init(): - C_rf[ci, cj, vk1o] = 0.0 - C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o * 4) + vk1i)] - for i_1 in T.serial(0, 16): - for j1_1 in T.serial(0, 16): - for k1o_1 in T.serial(0, 4): + C_rf[ci, cj, vk1o] = T.float32(0) + C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, vk1o * 4 + vk1i] + for i in T.serial(16): + for j1 in T.serial(16): + for k1o in T.serial(4): with T.block("C"): - vk1o_1, ci_1, cj_1 = T.axis.remap("RSS", [k1o_1, i_1, j1_1]) + ci, cj, vk1o = T.axis.remap("SSR", [i, j1, k1o]) with T.init(): - C[ci_1, cj_1] = 0.0 - C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1] + C[ci, cj] = T.float32(0) + C[ci, cj] = C[ci, cj] + C_rf[ci, cj, vk1o] for k2o, k2i in T.grid(4, 4): with T.block("D"): - di, dj = T.axis.remap("SS", [i_1, j1_1]) - dk = T.axis.R(16, k2o * 4 + k2i) + di, dj = T.axis.remap("SS", [i, j1]) + dk = T.axis.reduce(16, k2o * 4 + k2i) with T.init(): - D[di, dj] = 0.0 - D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj] - for j2 in T.serial(0, 16): + D[di, dj] = T.float32(0) + D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj] + for j2 in T.serial(16): for k3o, k3i in T.grid(4, 4): with T.block("E"): - ei, ej = T.axis.remap("SS", [i_1, j2]) - ek = T.axis.R(16, k3o * 4 + k3i) + ei, ej = T.axis.remap("SS", [i, j2]) + ek = T.axis.reduce(16, k3o * 4 + k3i) with T.init(): - E[ei, ej] = 0.0 - E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej] + E[ei, ej] = T.float32(0) + E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej] for k4o, k4i in T.grid(4, 4): with T.block("F"): - fi, fj = T.axis.remap("SS", [i_1, j2]) - fk = T.axis.R(16, k4o * 4 + k4i) + fi, fj = T.axis.remap("SS", [i, j2]) + fk = T.axis.reduce(16, k4o * 4 + k4i) with T.init(): - F[fi, fj] = 0.0 - F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj] + F[fi, fj] = T.float32(0) + F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj] # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -548,6 +634,28 @@ def test_reduction_rfactor_square_sum_square_root(): verify_trace_roundtrip(s, mod=transformed_square_sum_square_root) +def test_reduction_rfactor_square_sum_square_root_factor_one_1(): + s = tir.Schedule(transformed_square_sum_square_root_factor_one_1, debug_mask="all") + C = s.get_block("C") + _, _, f_i = s.get_loops(C) + rf_block = s.rfactor(f_i, 0) + tvm.ir.assert_structural_equal(s.mod["main"], square_sum_square_root_factor_one_1_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) + assert s.get(C).same_as(s.get(s.get_block("C"))) + verify_trace_roundtrip(s, mod=transformed_square_sum_square_root_factor_one_1) + + +def test_reduction_rfactor_square_sum_square_root_factor_one_2(): + s = tir.Schedule(transformed_square_sum_square_root_factor_one_2, debug_mask="all") + C = s.get_block("C") + _, f_o, _ = s.get_loops(C) + rf_block = s.rfactor(f_o, 1) + tvm.ir.assert_structural_equal(s.mod["main"], square_sum_square_root_factor_one_2_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("C_rf"))) + assert s.get(C).same_as(s.get(s.get_block("C"))) + verify_trace_roundtrip(s, mod=transformed_square_sum_square_root_factor_one_2) + + def test_reduction_rfactor_loop_multiple_children(): s = tir.Schedule(matmul_loop_multiple_children, debug_mask="all") k, _, _ = s.get_loops(s.get_block("C")) diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index 5d2676e41d1c..cf1f17b8a133 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -24,7 +24,7 @@ from tvm.tir.schedule.testing import verify_trace_roundtrip -# pylint: disable=no-member,invalid-name,unused-variable +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long @T.prim_func @@ -37,7 +37,29 @@ def elementwise(a: T.handle, b: T.handle) -> None: B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -# pylint: enable=no-member,invalid-name,unused-variable +@T.prim_func +def tiled_conv2d_with_padding(inputs: T.Buffer[(1, 224, 224, 3), "float32"], weight: T.Buffer[(7, 7, 3, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"]) -> None: + PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32") + for i0, i1, i2, i3 in T.grid(1, 230, 230, 3): + with T.block("PadInput"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1]) + T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) + PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(3 <= i1_1 and i1_1 < 227 and 3 <= i2_1 and i2_1 < 227, inputs[i0_1, i1_1 - 3, i2_1 - 3, i3_1], T.float32(0), dtype="float32") + for i0_0, i1_0, i2_0, i3_0, i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 4, 1, 1, 2, 4, 1, 7, 7, 1, 1, 1, 1, 1, 1, 1, 3, 1, 56, 7, 64): + with T.block("conv2d_nhwc"): + n = T.axis.spatial(1, 0) + h = T.axis.spatial(112, i1_1_1 * 56 + i1_3) + w = T.axis.spatial(112, i2_0 * 28 + i2_1_1 * 7 + i2_3) + co, rh, rw, rc = T.axis.remap("SRRR", [i3_3, i4_0, i5_0, i6_1]) + T.reads(conv2d_nhwc[n, h, w, co], PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc], weight[rh, rw, rc, co]) + T.writes(conv2d_nhwc[n, h, w, co]) + with T.init(): + conv2d_nhwc[n, h, w, co] = T.float32(0) + conv2d_nhwc[n, h, w, co] = conv2d_nhwc[n, h, w, co] + PadInput[n, h * 2 + rh, w * 2 + rw, co // 64 * 3 + rc] * weight[rh, rw, rc, co] + + +# pylint: enable=no-member,invalid-name,unused-variable,line-too-long def test_sample_categorical(): @@ -116,5 +138,22 @@ def test_sample_perfect_tile_composite(): verify_trace_roundtrip(sch, mod=elementwise) +def test_sample_compute_location(): + n = 100 + sch = tir.Schedule(tiled_conv2d_with_padding, seed=42, debug_mask="all") + pad_input = sch.get_block("PadInput") + decision_dict = dict() + for _ in range(n): + _ = sch.sample_compute_location(pad_input) # pylint: disable=invalid-name + decision = sch.trace.decisions[sch.trace.insts[-1]] + decision_dict[decision] = decision_dict[decision] + 1 if decision in decision_dict else 1 + + n_candidates = 8 + expected_rate = 1.0 / n_candidates + for _, cnt in decision_dict.items(): + assert (expected_rate - 0.03) * n <= cnt <= (expected_rate + 0.03) * n + + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 84ececebbcba..fd2115bddbed 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -66,7 +66,7 @@ def elementwise_symbolic_fused(a: T.handle, b: T.handle, n: T.int32) -> None: for i_j_k_fused in T.serial(0, (n * 16384)): with T.block("B"): vi = T.axis.S(128, T.floordiv(i_j_k_fused, n * 128)) - vj = T.axis.S(128, T.floormod(T.floordiv(i_j_k_fused, n), 128)) + vj = T.axis.S(128, T.floordiv(T.floormod(i_j_k_fused, n*128), n)) vk = T.axis.S(n, T.floormod(i_j_k_fused, n)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) @@ -164,7 +164,7 @@ def elementwise_fused(a: T.handle, b: T.handle) -> None: for fused in T.serial(0, 2097152): with T.block("B"): vi = T.axis.S(128, T.floordiv(fused, 16384)) - vj = T.axis.S(128, T.floormod(T.floordiv(fused, 128), 128)) + vj = T.axis.S(128, T.floordiv(T.floormod(fused, 16384), 128)) vk = T.axis.S(128, T.floormod(fused, 128)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) @@ -205,7 +205,7 @@ def elementwise_split_with_predicate(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128, 128]) for i0, i1, i2, j0, j1, k0, k1 in T.grid(1000, 2, 3, 1, 129, 3, 43): with T.block("B"): - T.where((i0 * 2 + i1) * 3 + i2 < 128 and j0 * 129 + j1 < 128 and k0 * 43 + k1 < 128) + T.where((i0 * 2 + i1) * 3 + i2 < 128 and j1 < 128 and k0 * 43 + k1 < 128) vi = T.axis.S(128, i0 * 6 + i1 * 3 + i2) vj = T.axis.S(128, j1) vk = T.axis.S(128, k0 * 43 + k1) @@ -223,8 +223,8 @@ def elementwise_fuse_with_opaque_block(a: T.handle, b: T.handle) -> None: T.reads( [ A[ - T.floormod(T.floordiv(T.floordiv(i_j_k_fused, 128), 128), 128), - T.floormod(T.floordiv(i_j_k_fused, 128), 128), + T.floordiv(i_j_k_fused, 16384), + T.floordiv(T.floormod(i_j_k_fused, 16384), 128), T.floormod(i_j_k_fused, 128), ] ] @@ -232,15 +232,15 @@ def elementwise_fuse_with_opaque_block(a: T.handle, b: T.handle) -> None: T.writes( [ B[ - T.floormod(T.floordiv(T.floordiv(i_j_k_fused, 128), 128), 128), - T.floormod(T.floordiv(i_j_k_fused, 128), 128), + T.floordiv(i_j_k_fused, 16384), + T.floordiv(T.floormod(i_j_k_fused, 16384), 128), T.floormod(i_j_k_fused, 128), ] ] ) with T.block("B"): vi = T.axis.S(128, T.floordiv(i_j_k_fused, 16384)) - vj = T.axis.S(128, T.floormod(T.floordiv(i_j_k_fused, 128), 128)) + vj = T.axis.S(128, T.floordiv(T.floormod(i_j_k_fused, 16384), 128)) vk = T.axis.S(128, T.floormod(i_j_k_fused, 128)) T.reads([A[vi, vj, vk]]) T.writes([B[vi, vj, vk]]) @@ -343,7 +343,7 @@ def elementwise_not_affine_fused(a: T.handle, b: T.handle) -> None: with T.block("B"): vi = T.axis.S( 127, - i * 32 + T.floormod(T.floordiv(j_k_fused, 128), T.min(31, 126 - i * 32) + 1), + i * 32 + T.floordiv(j_k_fused, 128), ) vj = T.axis.S(128, T.floormod(j_k_fused, 128)) T.reads([A[vi, vj]]) diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py new file mode 100644 index 000000000000..65fd79631949 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_tensorize.py @@ -0,0 +1,394 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import numpy as np +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,redundant-keyword-arg + +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@T.prim_func +def desc_func(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii = T.axis.S(16, vi + i) + vjj = T.axis.S(16, vj + j) + vkk = T.axis.R(16, vk + k) + C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] + + +@T.prim_func +def intrin_func(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + # These access region must be explicitly stated. Otherwise the auto-completed region starts from (0, 0) instead of (vi, vj) + T.reads([A[vi: vi+16, vk: vk+16], B[vj: vj+16, vk: vk+16], C[vi:vi+16, vj:vj+16]]) + T.writes([C[vi: vi+16, vj: vj+16]]) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) + C[vii, vjj] = C[vii, vjj] + B[vjj, vkk] * A[vii, vkk] + + + +@T.prim_func +def lower_intrin_func(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + vi = T.axis.S(16, 0) + vj = T.axis.S(16, 0) + vk = T.axis.R(16, 0) + T.reads([C[vi:vi + 16, vj:vj + 16], A[vi:vi + 16, vk:vk + 16], B[vj:vj + 16, vk:vk + 16]]) + T.writes(C[vi:vi + 16, vj:vj + 16]) + T.evaluate(T.tvm_mma_sync(C.data, C.elem_offset // 256, + A.data, A.elem_offset // 256, + B.data, B.elem_offset // 256, + C.data, C.elem_offset // 256, + dtype="handle")) + + +@T.prim_func +def tensorized_func(a: T.handle, b: T.handle, c: T.handle) -> None: + # function attr dict + C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + for i_outer, j_outer in T.grid(8, 8): + for i_inner_init, j_inner_init in T.grid(16, 16): + with T.block("init"): + vi_init = T.axis.S(128, ((i_outer * 16) + i_inner_init)) + vj_init = T.axis.S(128, ((j_outer * 16) + j_inner_init)) + C[vi_init, vj_init] = T.float32(0) + for k_outer in T.grid(8): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer]) + T.reads([C[vi*16:vi*16 + 16, vj*16:vj*16 + 16], A[vi*16:vi*16 + 16, vk*16:vk*16 + 16], B[vj*16:vj*16 + 16, vk*16:vk*16 + 16]]) + T.writes(C[vi*16:vi*16 + 16, vj*16:vj*16 + 16]) + A_elem_offset = T.var('int32') + B_elem_offset = T.var('int32') + C_elem_offset = T.var('int32') + A_sub = T.match_buffer(A[vi*16:vi*16+16, vk*16:vk*16+16], [16, 16], elem_offset=A_elem_offset) + B_sub = T.match_buffer(B[vj*16:vj*16+16, vk*16:vk*16+16], [16, 16], elem_offset=B_elem_offset) + C_sub = T.match_buffer(C[vi*16:vi*16+16, vj*16:vj*16+16], [16, 16], elem_offset=C_elem_offset) + T.evaluate( + T.tvm_mma_sync(C_sub.data, T.floordiv(C_sub.elem_offset, 256), + A_sub.data, T.floordiv(A_sub.elem_offset, 256), + B_sub.data, T.floordiv(B_sub.elem_offset, 256), + C_sub.data, T.floordiv(C_sub.elem_offset, 256), + dtype="handle")) + + +@T.prim_func +def batch_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [16, 128, 128]) + B = T.match_buffer(b, [16, 128, 128]) + C = T.match_buffer(c, [16, 128, 128]) + + for n, i, j in T.grid(16, 128, 128): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + C[vn, vi, vj] = T.float32(0) + + for n, i, j, k in T.grid(16, 128, 128, 128): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + + +@T.prim_func +def tensorized_batch_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + # function attr dict + C = T.match_buffer(c, [16, 128, 128]) + B = T.match_buffer(b, [16, 128, 128]) + A = T.match_buffer(a, [16, 128, 128]) + + for n, i, j in T.grid(16, 128, 128): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + C[vn, vi, vj] = T.float32(0) + # body + for n in range(0, 16): + for i, j, k in T.grid(8, 8, 8): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + T.reads([C[vn:vn + 1, vi*16:vi*16 + 16, vj*16:vj*16 + 16], A[vn:vn + 1, vi*16:vi*16 + 16, vk*16:vk*16 + 16], + B[vn:vn + 1, vj*16:vj*16 + 16, vk*16:vk*16 + 16]]) + T.writes(C[vn:vn + 1, vi*16:vi*16 + 16, vj*16:vj*16 + 16]) + A_elem_offset = T.var('int32') + B_elem_offset = T.var('int32') + C_elem_offset = T.var('int32') + A_sub = T.match_buffer(A[vn:vn + 1, vi*16:vi*16+16,vk*16:vk*16+16], (16, 16), elem_offset=A_elem_offset) + B_sub = T.match_buffer(B[vn:vn + 1, vj*16:vj*16+16,vk*16:vk*16+16], (16, 16), elem_offset=B_elem_offset) + C_sub = T.match_buffer(C[vn:vn + 1, vi*16:vi*16+16,vj*16:vj*16+16], (16, 16), elem_offset=C_elem_offset) + T.evaluate( + T.tvm_mma_sync(C_sub.data, T.floordiv(C_sub.elem_offset, 256), + A_sub.data, T.floordiv(A_sub.elem_offset, 256), + B_sub.data, T.floordiv(B_sub.elem_offset, 256), + C_sub.data, T.floordiv(C_sub.elem_offset, 256), + dtype="handle")) + + +@T.prim_func +def batch_matmul_dot_product(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [1, 4, 4], "float32") + B = T.match_buffer(b, [1, 4, 4], "float32") + C = T.match_buffer(c, [1, 4, 4], "float32") + + t = T.var("int32") + T.attr(T.iter_var(t, None, "DataPar", ""), "pragma_import_llvm", + "; ModuleID = '/tmp/tmpur44d1nu/input0.cc'\n\ +source_filename = \"/tmp/tmpur44d1nu/input0.cc\"\n\ +target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"\n\ +target triple = \"x86_64-pc-linux-gnu\"\n\ +\n\ +; Function Attrs: noinline nounwind optnone uwtable\n\ +define dso_local i32 @vec4add(float* %0, i32 %1, float* %2, i32 %3, float* %4, i32 %5) #0 {\n\ + %7 = alloca float*, align 8\n\ + %8 = alloca i32, align 4\n\ + %9 = alloca float*, align 8\n\ + %10 = alloca i32, align 4\n\ + %11 = alloca float*, align 8\n\ + %12 = alloca i32, align 4\n\ + %13 = alloca i32, align 4\n\ + store float* %0, float** %7, align 8\n\ + store i32 %1, i32* %8, align 4\n\ + store float* %2, float** %9, align 8\n\ + store i32 %3, i32* %10, align 4\n\ + store float* %4, float** %11, align 8\n\ + store i32 %5, i32* %12, align 4\n\ + store i32 0, i32* %13, align 4\n\ + br label %14\n\ +\n\ +14: ; preds = %39, %6\n\ + %15 = load i32, i32* %13, align 4\n\ + %16 = icmp slt i32 %15, 4\n\ + br i1 %16, label %17, label %42\n\ +\n\ +17: ; preds = %14\n\ + %18 = load float*, float** %9, align 8\n\ + %19 = load i32, i32* %13, align 4\n\ + %20 = load i32, i32* %10, align 4\n\ + %21 = add nsw i32 %19, %20\n\ + %22 = sext i32 %21 to i64\n\ + %23 = getelementptr inbounds float, float* %18, i64 %22\n\ + %24 = load float, float* %23, align 4\n\ + %25 = load float*, float** %11, align 8\n\ + %26 = load i32, i32* %13, align 4\n\ + %27 = load i32, i32* %12, align 4\n\ + %28 = add nsw i32 %26, %27\n\ + %29 = sext i32 %28 to i64\n\ + %30 = getelementptr inbounds float, float* %25, i64 %29\n\ + %31 = load float, float* %30, align 4\n\ + %32 = fmul float %24, %31\n\ + %33 = load float*, float** %7, align 8\n\ + %34 = load i32, i32* %8, align 4\n\ + %35 = sext i32 %34 to i64\n\ + %36 = getelementptr inbounds float, float* %33, i64 %35\n\ + %37 = load float, float* %36, align 4\n\ + %38 = fadd float %37, %32\n\ + store float %38, float* %36, align 4\n\ + br label %39\n\ +\n\ +39: ; preds = %17\n\ + %40 = load i32, i32* %13, align 4\n\ + %41 = add nsw i32 %40, 1\n\ + store i32 %41, i32* %13, align 4\n\ + br label %14\n\ +\n\ +42: ; preds = %14\n\ + ret i32 0\n\ +}\n\ +\n\ +attributes #0 = { noinline nounwind optnone uwtable \"correctly-rounded-divide-sqrt-fp-math\"=\"false\" \"disable-tail-calls\"=\"false\" \"frame-pointer\"=\"all\" \"less-precise-fpmad\"=\"false\" \"min-legal-vector-width\"=\"0\" \"no-infs-fp-math\"=\"false\" \"no-jump-tables\"=\"false\" \"no-nans-fp-math\"=\"false\" \"no-signed-zeros-fp-math\"=\"false\" \"no-trapping-math\"=\"true\" \"stack-protector-buffer-size\"=\"8\" \"target-cpu\"=\"x86-64\" \"target-features\"=\"+cx8,+fxsr,+mmx,+sse,+sse2,+x87\" \"unsafe-fp-math\"=\"false\" \"use-soft-float\"=\"false\" }\n\ +\n\ +!llvm.module.flags = !{!0}\n\ +!llvm.ident = !{!1}\n\ +\n\ +!0 = !{i32 1, !\"wchar_size\", i32 4}\n\ +!1 = !{!\"Ubuntu clang version 11.0.0-++20200928083541+eb83b551d3e-1~exp1~20200928184208.110\"}\n\ +\n\ + ") + + for n, i, j in T.grid(1, 4, 4): + with T.block("init"): + vn, vi, vj = T.axis.remap("SSS", [n, i, j]) + C[vn, vi, vj] = T.float32(0) + + for n, i, j, k in T.grid(1, 4, 4, 4): + with T.block("update"): + vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k]) + C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk] + + +@T.prim_func +def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,)) + B = T.match_buffer(b, (4,)) + C = T.match_buffer(c, (1,)) + + with T.block("root"): + v0 = T.axis.R(4, 0) + for i in range(0, 4): + with T.block("update"): + vi = T.axis.R(4, v0 + i) + C[0] = C[0] + A[vi] * B[vi] + + +@T.prim_func +def dot_product_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4,), offset_factor=1) + B = T.match_buffer(b, (4,), offset_factor=1) + C = T.match_buffer(c, (1,), offset_factor=1) + + with T.block("root"): + v0 = T.axis.R(4, 0) + T.reads([C[0 : 1], A[v0 : v0 + 4], B[v0 : v0 + 4]]) + T.writes([C[0 : 1]]) + T.evaluate(T.call_extern("vec4add", C.data, C.elem_offset, A.data, A.elem_offset, B.data, B.elem_offset, dtype="int32")) + + +# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,redundant-keyword-arg +# fmt: on + +# pylint: disable=invalid-name + + +tir.TensorIntrin.register("test_identity_intrin", desc_func, intrin_func) +tir.TensorIntrin.register("test_mma_intrin", desc_func, lower_intrin_func) +tir.TensorIntrin.register("test_dot_product_intrin", dot_product_desc, dot_product_impl) + + +def test_tensorize_gemm(): + func = matmul + # schedule + s = tir.Schedule(func, debug_mask="all") + update = s.get_block("update") + i, j, k = s.get_loops(update) + io, ii = s.split(i, factors=[None, 16]) + jo, ji = s.split(j, factors=[None, 16]) + ko, ki = s.split(k, factors=[None, 16]) + s.reorder(io, jo, ko, ii, ji, ki) + s.decompose_reduction(update, ko) + s.tensorize(ii, "test_identity_intrin") + + func = tvm.build(s.mod["main"]) + a_np = np.random.uniform(size=(128, 128)).astype("float32") + b_np = np.random.uniform(size=(128, 128)).astype("float32") + a = tvm.nd.array(a_np) + b = tvm.nd.array(b_np) + c = tvm.nd.array(np.zeros((128, 128)).astype("float32")) + func(a, b, c) + tvm.testing.assert_allclose(c.numpy(), np.dot(a_np, b_np.transpose()), rtol=1e-6) + + +def test_tensorize_buffer_bind(): + func = matmul + # schedule + s = tir.Schedule(func, debug_mask="all") + update = s.get_block("update") + i, j, k = s.get_loops(update) + io, ii = s.split(i, factors=[None, 16]) + jo, ji = s.split(j, factors=[None, 16]) + ko, ki = s.split(k, factors=[None, 16]) + s.reorder(io, jo, ko, ii, ji, ki) + s.decompose_reduction(update, ko) + s.tensorize(ii, "test_mma_intrin") + tvm.ir.assert_structural_equal(tensorized_func, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_high_dim_tensorize(): + func = batch_matmul + s = tir.Schedule(func, debug_mask="all") + update = s.get_block("update") + _, i, j, k = s.get_loops(update) + io, ii = s.split(i, factors=[None, 16]) + jo, ji = s.split(j, factors=[None, 16]) + ko, ki = s.split(k, factors=[None, 16]) + s.reorder(io, jo, ko, ii, ji, ki) + s.tensorize(ii, "test_mma_intrin") + tvm.ir.assert_structural_equal(tensorized_batch_matmul, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=batch_matmul) + + +@pytest.mark.skip("failed") +def test_tensorize_dot_product(): + func = batch_matmul_dot_productt + s = tir.Schedule(func, debug_mask="all") + C = s.get_block("update") + _, _, _, k = s.get_loops(C) + _, ki = s.split(k, factors=[None, 4]) + s.tensorize(ki, "test_dot_product_intrin") + target = "llvm" + ctx = tvm.device(target, 0) + a_np = np.random.uniform(size=(1, 4, 4)).astype("float32") + b_np = np.random.uniform(size=(1, 4, 4)).astype("float32") + a = tvm.nd.array(a_np) + b = tvm.nd.array(b_np) + c = tvm.nd.array(np.zeros((1, 4, 4), dtype="float32"), ctx) + func = tvm.build(s.mod["main"], target=target) + func(a, b, c) + tvm.testing.assert_allclose( + c.numpy(), + np.matmul(a.numpy(), b.numpy().transpose(0, 2, 1)), + rtol=1e-5, + ) + verify_trace_roundtrip(sch=s, mod=func) + + +if __name__ == "__main__": + test_tensorize_gemm() + test_tensorize_buffer_bind() + test_high_dim_tensorize() + # test_tensorize_dot_product() diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py index f1c97c57b2ff..1923eb23af5b 100644 --- a/tests/python/unittest/test_tir_schedule_trace.py +++ b/tests/python/unittest/test_tir_schedule_trace.py @@ -82,6 +82,15 @@ def _make_compute_inline(input): # pylint: disable=redefined-builtin ) +def _make_split(inputs, outputs): # pylint: disable=redefined-builtin + return Instruction( + kind=InstructionKind.get("Split"), + inputs=inputs, + attrs=[], + outputs=outputs, + ) + + def _make_enter_postproc(): return Instruction( kind=InstructionKind.get("EnterPostproc"), @@ -129,6 +138,17 @@ def _make_trace_3(b0, b1, add_postproc): # pylint: disable=invalid-name return Trace(insts=insts, decisions={}) +def _make_trace_4(b0, l1, l2, l3): # pylint: disable=invalid-name + return Trace( + insts=[ + _make_get_block(name="B", output=b0), + _make_get_loops(input=b0, outputs=[l1]), + _make_split([l1, None, 32], [l2, l3]), + ], + decisions={}, + ) + + def test_trace_construct_1(): trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV()) assert str(trace) == "\n".join( @@ -235,6 +255,17 @@ def test_trace_simplified_2(): ) +def test_trace_simplified_3(): + trace = _make_trace_4(BlockRV(), LoopRV(), LoopRV(), LoopRV()).simplified(remove_postproc=False) + assert str(trace) == "\n".join( + ( + 'b0 = sch.get_block(name="B", func_name="main")', + "l1, = sch.get_loops(block=b0)", + "l2, l3 = sch.split(loop=l1, factors=[None, 32])", + ) + ) + + def test_apply_json_to_schedule_1(): trace = _make_trace_2(BlockRV()) json_obj = trace.as_json() diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py new file mode 100644 index 000000000000..0962e147ff96 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -0,0 +1,170 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest + +import tvm +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + +def packed_index_map_func(m, n): + return m // 16, n // 16, m % 16, n % 16 + + +@T.prim_func +def two_elementwise(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + +@T.prim_func +def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((8, 8, 16, 16), "float32") + C = T.match_buffer(c, (128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0 + + +@T.prim_func +def two_elementwise_transformed_input_buffer(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (8, 8, 16, 16), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi // 16, vj // 16, vi % 16, vj % 16] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + +@T.prim_func +def two_elementwise_transformed_output_buffer(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (8, 8, 16, 16), "float32") + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi // 16, vj // 16, vi % 16, vj % 16] = B[vi, vj] + 1.0 + + +@T.prim_func +def permuted_shared_memory(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + A_shared = T.alloc_buffer((128, 128), scope="shared") + for i0, j0, in T.grid(32, 4): + for fused_i1_j1 in T.thread_binding(0, 32, 'threadIdx.x'): + for j2 in T.vectorized(0, 4): + with T.block("A_shared"): + vi = T.axis.S(128, i0 * 4 + fused_i1_j1 // 8) + vj = T.axis.S(128, j0 * 32 + fused_i1_j1 % 8 * 4 + j2) + A_shared[vi, vj] = A[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A_shared[vi, vj] + 1.0 + + +@T.prim_func +def permuted_shared_memory_transformed(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + C = T.match_buffer(c, (128, 128)) + A_shared = T.alloc_buffer((32, 4, 4, 32), scope="shared") + for i0, j0, in T.grid(32, 4): + for fused_i1_j1 in T.thread_binding(0, 32, 'threadIdx.x'): + for j2 in T.vectorized(0, 4): + with T.block("A_shared"): + vi = T.axis.S(128, i0 * 4 + fused_i1_j1 // 8) + vj = T.axis.S(128, j0 * 32 + fused_i1_j1 % 8 * 4 + j2) + A_shared[vi // 4, vj // 32, vi % 4, (((vj % 32) // 8) ^ (vi % 4)) + vj % 8] = A[vi, vj] + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = A_shared[vi // 4, vj // 32, vi % 4, (((vj % 32) // 8) ^ (vi % 4)) + vj % 8] + 1.0 + + +# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks +# fmt: on + + +def test_two_elementwise_transform_intermediate_buffer(): + sch = tir.Schedule(two_elementwise, debug_mask="all") + block = sch.get_block("B") + sch.transform_layout(block, 0, False, packed_index_map_func) + tvm.ir.assert_structural_equal(two_elementwise_transformed_intermediate_buffer, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=two_elementwise) + + +def test_two_elementwise_transform_input_buffer(): + sch = tir.Schedule(two_elementwise, debug_mask="all") + block = sch.get_block("B") + sch.transform_layout(block, 0, True, packed_index_map_func) + print(sch.mod['main'].script()) + tvm.ir.assert_structural_equal(two_elementwise_transformed_input_buffer, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=two_elementwise) + + +def test_two_elementwise_transform_output_buffer(): + sch = tir.Schedule(two_elementwise, debug_mask="all") + block = sch.get_block("C") + sch.transform_layout(block, 0, False, packed_index_map_func) + tvm.ir.assert_structural_equal(two_elementwise_transformed_output_buffer, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=two_elementwise) + + +@pytest.mark.skip("xor is not supported by IntSet") +def test_permuted_layout(): + sch = tir.Schedule(permuted_shared_memory, debug_mask="all") + block = sch.get_block("A_shared") + sch.transform_layout(block, 0, False, + lambda i, j: (i // 4, j // 32, i % 4, (((j % 32) // 8) ^ (i % 4)) + j % 8)) + tvm.ir.assert_structural_equal(permuted_shared_memory_transformed, sch.mod['main']) + verify_trace_roundtrip(sch=sch, mod=permuted_shared_memory) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_transform_apply_block_bound_predicate.py b/tests/python/unittest/test_tir_transform_apply_block_bound_predicate.py new file mode 100644 index 000000000000..7e651fb7f1ea --- /dev/null +++ b/tests/python/unittest/test_tir_transform_apply_block_bound_predicate.py @@ -0,0 +1,187 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import tir, te +from tvm.script import tir as T + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.ApplyBlockBoundPredicate()(mod) + mod = tvm.tir.transform.Simplify()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed, True) + + +def _check_print(original): + func = original + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.LowerCrossThreadReduction()(mod) + mod = tvm.tir.transform.LowerInitBlock()(mod) + mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) + mod = tvm.tir.transform.Simplify()(mod) + print(mod["main"].script()) + mod = tvm.tir.transform.ApplyBlockBoundPredicate()(mod) + mod = tvm.tir.transform.Simplify()(mod) + print(mod["main"].script()) + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks + + +@T.prim_func +def read_out_of_bound_after_compute_at(A: T.Buffer[(16,), "float32"], C: T.Buffer[(16,), "float32"]) -> None: + B = T.alloc_buffer([16], dtype="float32") + for j in T.serial(16): + for ax0 in T.serial(2): + with T.block("B"): + v = T.axis.spatial(16, j + ax0) + T.reads(A[v]) + T.writes(B[v]) + T.block_attr({"require_bound_predicate":v >= 0 and v < 16}) + B[v] = A[v] + with T.block("C"): + v = T.axis.spatial(16, j) + T.reads(B[v : v + 2]) + T.writes(C[v]) + C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32") + + +@T.prim_func +def tiled_pooling_cache_after_compute_at(a: T.handle, b: T.handle) -> None: + X = T.match_buffer(a, [224, 224], dtype="float32") + Y = T.match_buffer(b, [224, 224], dtype="float32") + cache = T.alloc_buffer([224, 224], dtype="float32") + dache = T.alloc_buffer([224, 224], dtype="float32") + for hh_0, ww_0 in T.grid(28, 28): + for ax0, ax1 in T.grid(10, 10): + with T.block("cache"): + h = T.axis.spatial(224, hh_0 * 8 + ax0 - 1) + w = T.axis.spatial(224, ww_0 * 8 + ax1 - 1) + T.reads(X[h, w]) + T.writes(cache[h, w]) + T.block_attr({"require_bound_predicate":h >= 0 and h < 224 and w >= 0 and w < 224}) + cache[h, w] = X[h, w] + for ax0, ax1 in T.grid(10, 10): + with T.block("dache"): + h = T.axis.spatial(224, hh_0 * 8 + ax0 - 1) + w = T.axis.spatial(224, ww_0 * 8 + ax1 - 1) + T.reads(X[h, w]) + T.writes(dache[h, w]) + T.block_attr({"require_bound_predicate":h >= 0 and h < 224 and w >= 0 and w < 224}) + dache[h, w] = X[h, w] + for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3): + with T.block("compute"): + h = T.axis.spatial(224, hh_0 * 8 + hh_1) + w = T.axis.spatial(224, ww_0 * 8 + ww_1) + kh, kw = T.axis.remap("RR", [khh, kww]) + T.reads([Y[h, w], cache[h + kh - 1, w + kw - 1], dache[h + kh - 1, w + kw - 1]]) + T.writes([Y[h, w]]) + with T.init(): + Y[h, w] = 0.0 + Y[h, w] = T.max(Y[h, w], T.if_then_else( + T.likely(1 <= h + kh, dtype="bool") and \ + T.likely(h + kh < 225, dtype="bool") and \ + T.likely(1 <= w + kw, dtype="bool") and \ + T.likely(w + kw < 225, dtype="bool"), + cache[h + kh - 1, w + kw - 1]+ dache[h + kh - 1, w + kw - 1], 0.0, dtype="float32")) + + +@T.prim_func +def batch_norm_after_compute_at(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> None: + for i0_0 in T.serial(1): + with T.block(): + T.reads(A[0 : 64, 0 : 256, 0 : 256]) + T.writes(D[0 : 64]) + C = T.alloc_buffer([1], dtype="float32") + for ax0, ax1, ax2 in T.grid(64, 256, 256): + with T.block("C"): + b = T.axis.spatial(1, ax0) + i, j = T.axis.remap("RR", [ax1, ax2]) + T.reads(C[b], A[b, i, j]) + T.writes(C[b]) + T.block_attr({"require_bound_predicate":b >= 0 and b < 1}) + if i == 0 and j == 0: + C[b] = T.float32(0) + C[b] = C[b] + A[b, i, j] * A[b, i, j] + for i0_1 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("D"): + b = T.axis.spatial(1, i0_1) + T.where(i0_1 < 1) + T.reads(C[b]) + T.writes(D[b]) + D[b] = T.sqrt(C[b], dtype="float32") + + +@T.prim_func +def transformed_batch_norm(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> None: + for i0_0 in T.serial(1): + with T.block(): + T.reads(A[0 : 64, 0 : 256, 0 : 256]) + T.writes(D[0 : 64]) + C = T.alloc_buffer([1], dtype="float32") + for ax0, ax1, ax2 in T.grid(1, 256, 256): + with T.block("C"): + b = T.axis.spatial(1, 0) + i, j = T.axis.remap("RR", [ax1, ax2]) + T.reads(C[b], A[b, i, j]) + T.writes(C[b]) + if i == 0 and j == 0: + C[b] = T.float32(0) + C[b] = C[b] + A[b, i, j] * A[b, i, j] + for i0_1 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("D"): + b = T.axis.spatial(1, i0_1) + T.where(i0_1 < 1) + T.reads(C[b]) + T.writes(D[b]) + D[b] = T.sqrt(C[b], dtype="float32") + + +# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks +# fmt: on + + +def test_read_out_of_bound(): + # This IR should not be mutated in this pass. + _check(read_out_of_bound_after_compute_at, read_out_of_bound_after_compute_at) + + +def test_tiled_pooling_cache(): + # This IR should not be mutated in this pass. + _check(tiled_pooling_cache_after_compute_at, tiled_pooling_cache_after_compute_at) + + +def test_batch_norm(): + _check(batch_norm_after_compute_at, transformed_batch_norm) + + +def test_lower_te(): + x = te.placeholder((1,)) + y = te.compute((1,), lambda i: x[i] + 2) + s = te.create_schedule(y.op) + orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) + mod = tvm.tir.transform.ApplyBlockBoundPredicate()(orig_mod) + tvm.ir.assert_structural_equal(mod, orig_mod) # FlattenBuffer should do nothing on TE + + +if __name__ == "__main__": + test_read_out_of_bound() + test_tiled_pooling_cache() + test_batch_norm() + test_lower_te() diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index 57c87e5dedf4..80e50bdcdaba 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -220,7 +220,7 @@ def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None: with T.block(): T.reads(A[i * 8 : i * 8 + 8]) T.writes(C[i * 8 : i * 8 + 8]) - B = T.alloc_buffer((8,), "float32") + B = T.alloc_buffer((T.min(n, 1) * 8,), "float32") for j in range(0, 8): with T.block() as []: T.reads(A[i * 8 + j]) diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index ca3d4aa70d0b..a236d1610102 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -254,6 +254,64 @@ def annotated_loops(a: T.handle) -> None: A[i] = 0.0 +@T.prim_func +def tiled_pooling_cache_after_compute_at(a: T.handle, b: T.handle) -> None: + X = T.match_buffer(a, [224, 224], dtype="float32") + Y = T.match_buffer(b, [224, 224], dtype="float32") + # body + # with T.block("root") + cache = T.alloc_buffer([10, 10], dtype="float32") + dache = T.alloc_buffer([10, 10], dtype="float32") + for hh_0, ww_0 in T.grid(28, 28): + for ax0, ax1 in T.grid(10, 10): + with T.block("cache"): + T.reads(X[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1]) + T.writes(cache[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1]) + T.block_attr({"require_bound_predicate":hh_0 * 8 - 1 + ax0 >= 0 and hh_0 * 8 - 1 + ax0 < 224 and ww_0 * 8 - 1 + ax1 >= 0 and ww_0 * 8 - 1 + ax1 < 224}) + cache[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1] = X[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1] + for ax0, ax1 in T.grid(10, 10): + with T.block("dache"): + T.reads(X[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1]) + T.writes(dache[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1]) + T.block_attr({"require_bound_predicate":hh_0 * 8 - 1 + ax0 >= 0 and hh_0 * 8 - 1 + ax0 < 224 and ww_0 * 8 - 1 + ax1 >= 0 and ww_0 * 8 - 1 + ax1 < 224}) + dache[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1] = X[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1] + for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3): + with T.block("compute"): + T.reads(Y[hh_0 * 8 + hh_1, ww_0 * 8 + ww_1], cache[hh_0 * 8 + hh_1 + khh - 1, ww_0 * 8 + ww_1 + kww - 1], dache[hh_0 * 8 + hh_1 + khh - 1, ww_0 * 8 + ww_1 + kww - 1]) + T.writes(Y[hh_0 * 8 + hh_1, ww_0 * 8 + ww_1]) + Y[hh_0 * 8 + hh_1, ww_0 * 8 + ww_1] = T.max(Y[hh_0 * 8 + hh_1, ww_0 * 8 + ww_1], + T.if_then_else(T.likely(1 <= hh_0 * 8 + hh_1 + khh, dtype="bool") + and T.likely(hh_0 * 8 + hh_1 + khh < 225, dtype="bool") + and T.likely(1 <= ww_0 * 8 + ww_1 + kww, dtype="bool") + and T.likely(ww_0 * 8 + ww_1 + kww < 225, dtype="bool"), + cache[hh_0 * 8 + hh_1 + khh - 1, ww_0 * 8 + ww_1 + kww - 1] + + dache[hh_0 * 8 + hh_1 + khh - 1, ww_0 * 8 + ww_1 + kww - 1], + T.float32(0), dtype="float32")) + + +@T.prim_func +def flattened_tiled_pooling_cache_after_compute_at(X: T.Buffer[(224, 224), "float32"], Y: T.Buffer[(224, 224), "float32"]) -> None: + cache = T.allocate([100], "float32", "global") + dache = T.allocate([100], "float32", "global") + for hh_0, ww_0 in T.grid(28, 28): + for ax0, ax1 in T.grid(10, 10): + if 1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225 and 1 <= ww_0 * 8 + ax1 and ww_0 * 8 + ax1 < 225: + T.store(cache, hh_0 * 80 + ax0 * 10 + ww_0 * 8 + ax1 - 11, T.load("float32", X.data, hh_0 * 1792 + ax0 * 224 + ww_0 * 8 + ax1 - 225), True) + for ax0, ax1 in T.grid(10, 10): + if 1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225 and 1 <= ww_0 * 8 + ax1 and ww_0 * 8 + ax1 < 225: + T.store(dache, hh_0 * 80 + ax0 * 10 + ww_0 * 8 + ax1 - 11, T.load("float32", X.data, hh_0 * 1792 + ax0 * 224 + ww_0 * 8 + ax1 - 225), True) + for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3): + T.store(Y.data, hh_0 * 1792 + hh_1 * 224 + ww_0 * 8 + ww_1, + T.max(T.load("float32", Y.data, hh_0 * 1792 + hh_1 * 224 + ww_0 * 8 + ww_1), + T.if_then_else(T.likely(1 <= hh_0 * 8 + hh_1 + khh, dtype="bool") + and T.likely(hh_0 * 8 + hh_1 + khh < 225, dtype="bool") + and T.likely(1 <= ww_0 * 8 + ww_1 + kww, dtype="bool") + and T.likely(ww_0 * 8 + ww_1 + kww < 225, dtype="bool"), + T.load("float32", cache, hh_0 * 80 + hh_1 * 10 + khh * 10 + ww_0 * 8 + ww_1 + kww - 11) + + T.load("float32", dache, hh_0 * 80 + hh_1 * 10 + khh * 10 + ww_0 * 8 + ww_1 + kww - 11), + T.float32(0), dtype="float32")), True) + + def test_elementwise(): _check(compacted_elementwise_func, flattened_elementwise_func) @@ -305,6 +363,10 @@ def test_annotated_loops(): tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0)) +def test_bound_predicate(): + _check(tiled_pooling_cache_after_compute_at, flattened_tiled_pooling_cache_after_compute_at) + + if __name__ == "__main__": test_elementwise() test_gpu_workload() diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py new file mode 100644 index 000000000000..1c9b69665d1c --- /dev/null +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -0,0 +1,741 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import sys + +import tvm +from tvm import tir, te, TVMError +from tvm.script import tir as T + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.InjectSoftwarePipeline()(mod) + mod = tvm.tir.transform.Simplify()(mod) + print(mod['main'].script()) + tvm.ir.assert_structural_equal(mod["main"], transformed, True) + + +def _check_error(func): + mod = tvm.IRModule.from_expr(func) + with pytest.raises(ValueError): + tvm.tir.transform.InjectSoftwarePipeline()(mod) + + +@T.prim_func +def simple_compute(a: T.handle, c: T.handle): + A = T.match_buffer(a, (16, 16), dtype="float32") + C = T.match_buffer(c, (16, 16), dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1], 'software_pipeline_order': [0, 1]}): + with T.block(): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = B[tx, 0] + T.float32(1) + + +@T.prim_func +def transformed_simple_compute(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16], dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16]]) + T.writes([C[tx, 0:16]]) + B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0]]) + T.writes([B[0, tx, 0]]) + B[0, tx, 0] = A[tx, 0] * T.float32(2) + with T.block(): + T.reads([A[tx, 1:16], B[0:2, tx, 0]]) + T.writes([B[0:2, tx, 0], C[tx, 0:15]]) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1]]) + T.writes([B[(i + 1) % 2, tx, 0]]) + B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) + with T.block(): + T.reads([B[i % 2, tx, 0]]) + T.writes([C[tx, i]]) + C[tx, i] = B[i % 2, tx, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 0]]) + T.writes([C[tx, 15]]) + C[tx, 15] = B[1, tx, 0] + T.float32(1) + + +@T.prim_func +def nested_pipeline_simple(a: T.handle, c: T.handle): + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16, 16], dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, + annotations={"software_pipeline_stage": [0, 1, 1, 1], + "software_pipeline_order": [0, 1, 2, 3]}): + with T.block(): + T.reads(A[tx, i, 0:16]) + T.writes(C[tx, i, 0:16]) + A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") + for j in T.serial(0, 16): + with T.block(): + T.reads(A[tx, i, j]) + T.writes(A_shared[tx, 0, j]) + A_shared[tx, 0, j] = A[tx, i, j] + for j in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1] + }, + ): + with T.block(): + T.reads(A_shared[tx, 0, j]) + T.writes(C[tx, i, j]) + B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A_shared[tx, i, j]) + T.writes(B[tx, i, 0]) + B[tx, i, 0] = A_shared[tx, 0, j] * T.float32(2) + with T.block(): + T.reads(B[tx, i, 0]) + T.writes(C[tx, i, j]) + C[tx, i, j] = B[tx, i, 0] + T.float32(1) + + +@T.prim_func +def transformed_nested_pipeline_simple(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16, 16], dtype="float32") + # body + # with T.block("root") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16, 0:16]]) + T.writes([C[tx, 0:16, 0:16]]) + A_shared = T.alloc_buffer([2, 16, 1, 16], dtype="float32", scope="shared") + B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0, 0:16]]) + T.writes([A_shared[0, tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, 0, j]]) + T.writes([A_shared[0, tx, 0, j]]) + A_shared[0, tx, 0, j] = A[tx, 0, j] + with T.block(): + T.reads([A[tx, 1:16, 0:16], A_shared[0:2, tx, 0:15, 0:16], B[0:2, tx, 0:15, 0]]) + T.writes([A_shared[0:2, tx, 0, 0:16], B[0:2, tx, 0:15, 0], C[tx, 0:15, 0:16]]) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1, 0:16]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, i + 1, j]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, j]]) + A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j] + with T.block(): + T.reads([A_shared[i % 2, tx, i, 0]]) + T.writes([B[0, tx, i, 0]]) + B[0, tx, i, 0] = A_shared[i % 2, tx, 0, 0] * T.float32(2) + with T.block(): + T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]]) + T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_shared[i % 2, tx, i, j + 1]]) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_shared[ + i % 2, tx, 0, j + 1 + ] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, i, 0]]) + T.writes([C[tx, i, j]]) + C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, i, 0]]) + T.writes([C[tx, i, 15]]) + C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[1, tx, 15, 0:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) + with T.block(): + T.reads([A_shared[1, tx, 15, 0]]) + T.writes([B[0, tx, 15, 0]]) + B[0, tx, 15, 0] = A_shared[1, tx, 0, 0] * T.float32(2) + with T.block(): + T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_shared[1, tx, 15, j + 1]]) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, 15, 0]]) + T.writes([C[tx, 15, j]]) + C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 15, 0]]) + T.writes([C[tx, 15, 15]]) + C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) + + +@T.prim_func +def nested_pipeline_prefetch_inner(a: T.handle, c: T.handle): + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16, 16], dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 0, 1, 1], "software_pipeline_order": [0, 2, 1, 3]}): + with T.block(): + T.reads(A[tx, i, 0:16]) + T.writes(C[tx, i, 0:16]) + A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") + for j in T.serial(0, 16): + with T.block(): + T.reads(A[tx, i, j]) + T.writes(A_shared[tx, 0, j]) + A_shared[tx, 0, j] = A[tx, i, j] + for j in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1], + }, + ): + with T.block(): + T.reads(A_shared[tx, 0, j]) + T.writes(C[tx, i, j]) + B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A_shared[tx, i, j]) + T.writes(B[tx, i, 0]) + B[tx, i, 0] = A_shared[tx, 0, j] * T.float32(2) + with T.block(): + T.reads(B[tx, i, 0]) + T.writes(C[tx, i, j]) + C[tx, i, j] = B[tx, i, 0] + T.float32(1) + + +@T.prim_func +def transformed_nested_pipeline_prefetch_inner(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16, 16], dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16, 0:16]]) + T.writes([C[tx, 0:16, 0:16]]) + A_shared = T.alloc_buffer([2, 16, 1, 16], dtype="float32", scope="shared") + B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0, 0:16], A_shared[0, tx, 0, 0]]) + T.writes([A_shared[0, tx, 0, 0:16], B[0, tx, 0, 0]]) + with T.block(): + T.reads([A[tx, 0, 0:16]]) + T.writes([A_shared[0, tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, 0, j]]) + T.writes([A_shared[0, tx, 0, j]]) + A_shared[0, tx, 0, j] = A[tx, 0, j] + with T.block(): + T.reads([A_shared[0, tx, 0, 0]]) + T.writes([B[0, tx, 0, 0]]) + B[0, tx, 0, 0] = A_shared[0, tx, 0, 0] * T.float32(2) + with T.block(): + T.reads([A[tx, 1:16, 0:16], A_shared[0:2, tx, 0:16, 0:16], B[0:2, tx, 0:15, 0]]) + T.writes([A_shared[0:2, tx, 0, 0:16], B[0:2, tx, 0:16, 0], C[tx, 0:15, 0:16]]) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1, 0:16]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, i + 1, j]]) + T.writes([A_shared[(i + 1) % 2, tx, 0, j]]) + A_shared[(i + 1) % 2, tx, 0, j] = A[tx, i + 1, j] + with T.block(): + T.reads([A_shared[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]]) + T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_shared[i % 2, tx, i, j + 1]]) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_shared[ + i % 2, tx, 0, j + 1 + ] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, i, 0]]) + T.writes([C[tx, i, j]]) + C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[(i + 1) % 2, tx, i + 1, 0]]) + T.writes([B[0, tx, i + 1, 0]]) + B[0, tx, i + 1, 0] = A_shared[(i + 1) % 2, tx, 0, 0] * T.float32(2) + with T.block(): + T.reads([B[1, tx, i, 0]]) + T.writes([C[tx, i, 15]]) + C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) + with T.block(): + T.reads([A_shared[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_shared[1, tx, 15, j + 1]]) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_shared[1, tx, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, 15, 0]]) + T.writes([C[tx, 15, j]]) + C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 15, 0]]) + T.writes([C[tx, 15, 15]]) + C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) + + +@T.prim_func +def nested_pipeline_interleaving(a: T.handle, c: T.handle): + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16, 16], dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 0, 0, 1, 1], "software_pipeline_order": [0, 2, 3, 1, 4]}): + with T.block(): + T.reads(A[tx, i, 0:16]) + T.writes(C[tx, i, 0:16]) + A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") + A_local = T.alloc_buffer((1, 1, 16), dtype="float32", scope="local") + for j in T.serial(0, 16): + with T.block(): + T.reads(A[tx, i, j]) + T.writes(A_shared[tx, 0, j]) + A_shared[tx, 0, j] = A[tx, i, j] + for j in T.serial(0, 16): + with T.block(): + T.reads(A_shared[tx, 0, j]) + T.writes(A_local[0, 0, j]) + A_local[0, 0, j] = A_shared[tx, i, j] + for j in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1], + }, + ): + with T.block(): + T.reads(A_local[0, 0, j]) + T.writes(C[tx, i, j]) + B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A_local[tx, i, j]) + T.writes(B[tx, i, 0]) + B[tx, i, 0] = A_local[0, 0, j] * T.float32(2) + with T.block(): + T.reads(B[tx, i, 0]) + T.writes(C[tx, i, j]) + C[tx, i, j] = B[tx, i, 0] + T.float32(1) + + +@T.prim_func +def transformed_nested_pipeline_interleaving(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16, 16], dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16, 0:16]]) + T.writes([C[tx, 0:16, 0:16]]) + A_shared = T.alloc_buffer([16, 1, 16], dtype="float32", scope="shared") + A_local = T.alloc_buffer([1, 1, 16], dtype="float32", scope="local") + B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0, 0:16], A_shared[tx, 0, 0:16], A_local[tx, 0, 0]]) + T.writes([A_shared[tx, 0, 0:16], A_local[0, 0, 0:16], B[0, tx, 0, 0]]) + with T.block(): + T.reads([A[tx, 0, 0:16]]) + T.writes([A_shared[tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, 0, j]]) + T.writes([A_shared[tx, 0, j]]) + A_shared[tx, 0, j] = A[tx, 0, j] + with T.block(): + T.reads([A_shared[tx, 0, 0:16]]) + T.writes([A_local[0, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A_shared[tx, 0, j]]) + T.writes([A_local[0, 0, j]]) + A_local[0, 0, j] = A_shared[tx, 0, j] + with T.block(): + T.reads([A_local[tx, 0, 0]]) + T.writes([B[0, tx, 0, 0]]) + B[0, tx, 0, 0] = A_local[0, 0, 0] * T.float32(2) + with T.block(): + T.reads( + [ + A[tx, 1:16, 0:16], + A_local[tx, 0:16, 0:16], + B[0:2, tx, 0:15, 0], + A_shared[tx, 0, 0:16], + ] + ) + T.writes( + [ + A_shared[tx, 0, 0:16], + B[0:2, tx, 0:16, 0], + C[tx, 0:15, 0:16], + A_local[0, 0, 0:16], + ] + ) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1, 0:16]]) + T.writes([A_shared[tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, i + 1, j]]) + T.writes([A_shared[tx, 0, j]]) + A_shared[tx, 0, j] = A[tx, i + 1, j] + with T.block(): + T.reads([A_local[tx, i, 1:16], B[0:2, tx, i, 0]]) + T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_local[tx, i, j + 1]]) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_local[0, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, i, 0]]) + T.writes([C[tx, i, j]]) + C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[tx, 0, 0:16]]) + T.writes([A_local[0, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A_shared[tx, 0, j]]) + T.writes([A_local[0, 0, j]]) + A_local[0, 0, j] = A_shared[tx, i + 1, j] + with T.block(): + T.reads([A_local[tx, i + 1, 0]]) + T.writes([B[0, tx, i + 1, 0]]) + B[0, tx, i + 1, 0] = A_local[0, 0, 0] * T.float32(2) + with T.block(): + T.reads([B[1, tx, i, 0]]) + T.writes([C[tx, i, 15]]) + C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_local[tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) + with T.block(): + T.reads([A_local[tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_local[tx, 15, j + 1]]) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_local[0, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, 15, 0]]) + T.writes([C[tx, 15, j]]) + C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 15, 0]]) + T.writes([C[tx, 15, 15]]) + C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) + + +@T.prim_func +def nested_pipeline_double_buffer(a: T.handle, c: T.handle): + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16, 16], dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 0, 0, 1, 1], "software_pipeline_order": [0, 2, 3, 1, 4]}): + with T.block(): + T.reads(A[tx, i, 0:16]) + T.writes(C[tx, i, 0:16]) + A_shared = T.alloc_buffer((16, 1, 16), dtype="float32", scope="shared") + A_local = T.alloc_buffer((1, 1, 16), dtype="float32", scope="local") + for j in T.serial(0, 16): + with T.block(): + T.reads(A[tx, i, j]) + T.writes(A_shared[tx, 0, j]) + A_shared[tx, 0, j] = A[tx, i, j] + for j in T.serial(0, 16): + with T.block(): + T.block_attr({"double_buffer_scope": 0}) + T.reads(A_shared[tx, 0, j]) + T.writes(A_local[0, 0, j]) + A_local[0, 0, j] = A_shared[tx, i, j] + for j in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1], + }, + ): + with T.block(): + T.reads(A_local[0, 0, j]) + T.writes(C[tx, i, j]) + B = T.alloc_buffer((16, 1, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A_local[tx, i, j]) + T.writes(B[tx, i, 0]) + B[tx, i, 0] = A_local[0, 0, j] * T.float32(2) + with T.block(): + T.reads(B[tx, i, 0]) + T.writes(C[tx, i, j]) + C[tx, i, j] = B[tx, i, 0] + T.float32(1) + + +@T.prim_func +def transformed_nested_pipeline_double_buffer(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [16, 16, 16], dtype="float32") + C = T.match_buffer(c, [16, 16, 16], dtype="float32") + # body + # with T.block("root") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16, 0:16]]) + T.writes([C[tx, 0:16, 0:16]]) + A_shared = T.alloc_buffer([16, 1, 16], dtype="float32", scope="shared") + A_local = T.alloc_buffer([2, 1, 1, 16], dtype="float32", scope="local") + B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0, 0:16], A_shared[tx, 0, 0:16], A_local[0, tx, 0, 0]]) + T.writes([A_shared[tx, 0, 0:16], A_local[0, 0, 0, 0:16], B[0, tx, 0, 0]]) + with T.block(): + T.reads([A[tx, 0, 0:16]]) + T.writes([A_shared[tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, 0, j]]) + T.writes([A_shared[tx, 0, j]]) + A_shared[tx, 0, j] = A[tx, 0, j] + with T.block(): + T.reads([A_shared[tx, 0, 0:16]]) + T.writes([A_local[0, 0, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A_shared[tx, 0, j]]) + T.writes([A_local[0, 0, 0, j]]) + T.block_attr({"double_buffer_scope": 0}) + A_local[0, 0, 0, j] = A_shared[tx, 0, j] + with T.block(): + T.reads([A_local[0, tx, 0, 0]]) + T.writes([B[0, tx, 0, 0]]) + B[0, tx, 0, 0] = A_local[0, 0, 0, 0] * T.float32(2) + with T.block(): + T.reads( + [ + A[tx, 1:16, 0:16], + A_local[0:2, tx, 0:16, 0:16], + B[0:2, tx, 0:15, 0], + A_shared[tx, 0, 0:16], + ] + ) + T.writes( + [ + A_shared[tx, 0, 0:16], + B[0:2, tx, 0:16, 0], + C[tx, 0:15, 0:16], + A_local[0:2, 0, 0, 0:16], + ] + ) + for i in T.serial(0, 15): + with T.block(): + T.reads([A[tx, i + 1, 0:16]]) + T.writes([A_shared[tx, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A[tx, i + 1, j]]) + T.writes([A_shared[tx, 0, j]]) + A_shared[tx, 0, j] = A[tx, i + 1, j] + with T.block(): + T.reads([A_local[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]]) + T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_local[i % 2, tx, i, j + 1]]) + T.writes([B[(j + 1) % 2, tx, i, 0]]) + B[(j + 1) % 2, tx, i, 0] = A_local[i % 2, 0, 0, j + 1] * T.float32( + 2 + ) + with T.block(): + T.reads([B[j % 2, tx, i, 0]]) + T.writes([C[tx, i, j]]) + C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_shared[tx, 0, 0:16]]) + T.writes([A_local[(i + 1) % 2, 0, 0, 0:16]]) + for j in T.serial(0, 16): + with T.block(): + T.reads([A_shared[tx, 0, j]]) + T.writes([A_local[(i + 1) % 2, 0, 0, j]]) + T.block_attr({"double_buffer_scope": 0}) + A_local[(i + 1) % 2, 0, 0, j] = A_shared[tx, i + 1, j] + with T.block(): + T.reads([A_local[(i + 1) % 2, tx, i + 1, 0]]) + T.writes([B[0, tx, i + 1, 0]]) + B[0, tx, i + 1, 0] = A_local[(i + 1) % 2, 0, 0, 0] * T.float32(2) + with T.block(): + T.reads([B[1, tx, i, 0]]) + T.writes([C[tx, i, 15]]) + C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1) + with T.block(): + T.reads([A_local[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]]) + with T.block(): + T.reads([A_local[1, tx, 15, 1:16], B[0:2, tx, 15, 0]]) + T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]]) + for j in T.serial(0, 15): + with T.block(): + T.reads([A_local[1, tx, 15, j + 1]]) + T.writes([B[(j + 1) % 2, tx, 15, 0]]) + B[(j + 1) % 2, tx, 15, 0] = A_local[1, 0, 0, j + 1] * T.float32(2) + with T.block(): + T.reads([B[j % 2, tx, 15, 0]]) + T.writes([C[tx, 15, j]]) + C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 15, 0]]) + T.writes([C[tx, 15, 15]]) + C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1) + + +@T.prim_func +def simple_compute_incorrect_reorder(a: T.handle, d: T.handle): + A = T.match_buffer(a, (16, 16), dtype="float32") + D = T.match_buffer(d, (16, 16), dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1, 1], "software_pipeline_order": [0, 2, 1]}): + with T.block(): + T.reads(A[tx, i]) + T.writes(D[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + C = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, 0]) + C[tx, 0] = B[tx, 0] + T.float32(2) + with T.block(): + T.reads(C[tx, 0]) + T.writes(D[tx, i]) + D[tx, i] = C[tx, 0] + T.float32(1) + + +@T.prim_func +def simple_compute_conflicting_order(a: T.handle, d: T.handle): + A = T.match_buffer(a, (16, 16), dtype="float32") + D = T.match_buffer(d, (16, 16), dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1, 1], "software_pipeline_order": [ 0, 1, 1]}): + with T.block(): + T.reads(A[tx, i]) + T.writes(D[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + C = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, 0]) + C[tx, 0] = B[tx, 0] + T.float32(2) + with T.block(): + T.reads(C[tx, 0]) + T.writes(D[tx, i]) + D[tx, i] = C[tx, 0] + T.float32(1) + + +@T.prim_func +def simple_compute_missing_annotation(a: T.handle, c: T.handle): + A = T.match_buffer(a, (16, 16), dtype="float32") + C = T.match_buffer(c, (16, 16), dtype="float32") + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial(0, 16, annotations={"software_pipeline_stage": [0, 1]}): + with T.block(): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = B[tx, 0] + T.float32(1) + + + +def test_simple_compute(): + _check(simple_compute, transformed_simple_compute) + + +def test_nest_pipeline_simple(): + _check(nested_pipeline_simple, transformed_nested_pipeline_simple) + + +def test_nest_pipeline_prefetch_inner(): + _check(nested_pipeline_prefetch_inner, transformed_nested_pipeline_prefetch_inner) + + +def test_nest_pipeline_interleaving(): + _check(nested_pipeline_interleaving, transformed_nested_pipeline_interleaving) + + +def test_nest_pipeline_double_buffer(): + _check(nested_pipeline_double_buffer, transformed_nested_pipeline_double_buffer) + + +# def test_error_reorder(): +# _check_error(simple_compute_incorrect_reorder) + + +# def test_error_conflicting_order(): +# _check_error(simple_compute_conflicting_order) + + +def test_error_missing_annotation(): + _check_error(simple_compute_missing_annotation) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py index 4fa3ab0c550c..5b3d7283f14f 100644 --- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py @@ -327,6 +327,162 @@ def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None: B[vi] = reduce_temp0[0] +@T.prim_func +def single_reduction_loop_with_block_predicate( + A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"] +) -> None: + T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + for i0 in T.serial(256): + for ax0, ax1_0 in T.grid(1, 1): + for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): + with T.block("T_softmax_maxelem"): + i0_1 = T.axis.spatial(256, i0) + k = T.axis.reduce(256, ax1_1) + T.where(ax1_0 * 512 + ax1_1 < 256) + T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k]) + T.writes(T_softmax_maxelem_shared[i0_1]) + with T.init(): + T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e38) + T_softmax_maxelem_shared[i0_1] = T.max( + T_softmax_maxelem_shared[i0_1], A[i0_1, k] + ) + for ax0, ax1_0 in T.grid(1, 1): + for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): + with T.block("T_softmax_expsum"): + i0_2 = T.axis.spatial(256, i0) + k = T.axis.reduce(256, ax1_1) + T.where(ax1_0 * 512 + ax1_1 < 256) + T.reads( + T_softmax_expsum_shared[i0_2], A[i0_2, k], T_softmax_maxelem_shared[i0_2] + ) + T.writes(T_softmax_expsum_shared[i0_2]) + with T.init(): + T_softmax_expsum_shared[i0_2] = T.float32(0) + T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp( + A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32" + ) + for i1_0 in T.serial(1): + for i1_1 in T.thread_binding(512, thread="threadIdx.x"): + with T.block("T_softmax_norm"): + i0_3 = T.axis.spatial(256, i0) + i1 = T.axis.spatial(256, i1_1) + T.where(i1_0 * 512 + i1_1 < 256) + T.reads( + A[i0_3, i1], T_softmax_maxelem_shared[i0_3], T_softmax_expsum_shared[i0_3] + ) + T.writes(T_softmax_norm[i0_3, i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[i0_3, i1] = ( + T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3], dtype="float32") + / T_softmax_expsum_shared[i0_3] + ) + + +@T.prim_func +def lowered_single_reduction_loop_with_block_predicate( + A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"] +) -> None: + T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + cross_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + in_thread_0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + cross_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + in_thread_1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i0 in T.serial(256): + for ax0, ax1_0 in T.grid(1, 1): + for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): + with T.block("T_softmax_maxelem_in_thread_init"): + T.reads() + T.writes(in_thread_0[0]) + in_thread_0[0] = T.float32(-3.4028234663852886e38) + with T.block("T_softmax_maxelem_in_thread"): + i0_1 = T.axis.spatial(256, i0) + k = T.axis.reduce(256, ax1_1) + T.where(ax1_0 * 512 + ax1_1 < 256) + T.reads(A[i0_1, k], in_thread_0[0]) + T.writes(in_thread_0[0]) + in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k]) + with T.block("T_softmax_maxelem_cross_thread"): + T.reads(in_thread_0[0]) + T.writes(cross_thread_0[0]) + T.attr( + T.comm_reducer( + lambda x, y: T.max(x, y), [T.float32(-3.4028234663852886e38)] + ), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + in_thread_0[0], + True, + cross_thread_0.data, + ax1_1, + dtype="handle", + ) + ) + with T.block("T_softmax_maxelem_write_back"): + i0_2 = T.axis.spatial(256, i0) + T.reads(cross_thread_0[0]) + T.writes(T_softmax_maxelem_shared[i0_2]) + T_softmax_maxelem_shared[i0_2] = cross_thread_0[0] + for ax0, ax1_0 in T.grid(1, 1): + for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): + with T.block("T_softmax_expsum_in_thread_init"): + T.reads() + T.writes(in_thread_1[0]) + in_thread_1[0] = T.float32(0) + with T.block("T_softmax_expsum_in_thread"): + i0_3 = T.axis.spatial(256, i0) + k = T.axis.reduce(256, ax1_1) + T.where(ax1_0 * 512 + ax1_1 < 256) + T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3], in_thread_1[0]) + T.writes(in_thread_1[0]) + in_thread_1[0] = in_thread_1[0] + T.exp( + A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32" + ) + with T.block("T_softmax_expsum_cross_thread"): + T.reads(in_thread_1[0]) + T.writes(cross_thread_1[0]) + T.attr( + T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(1), + in_thread_1[0], + True, + cross_thread_1.data, + ax1_1, + dtype="handle", + ) + ) + with T.block("T_softmax_expsum_write_back"): + i0_4 = T.axis.spatial(256, i0) + T.reads(cross_thread_1[0]) + T.writes(T_softmax_expsum_shared[i0_4]) + T_softmax_expsum_shared[i0_4] = cross_thread_1[0] + for i1_0 in T.serial(1): + for i1_1 in T.thread_binding(512, thread="threadIdx.x"): + with T.block("T_softmax_norm"): + i0_5 = T.axis.spatial(256, i0) + i1 = T.axis.spatial(256, i1_1) + T.where(i1_0 * 512 + i1_1 < 256) + T.reads( + A[i0_5, i1], T_softmax_maxelem_shared[i0_5], T_softmax_expsum_shared[i0_5] + ) + T.writes(T_softmax_norm[i0_5, i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[i0_5, i1] = ( + T.exp(A[i0_5, i1] - T_softmax_maxelem_shared[i0_5], dtype="float32") + / T_softmax_expsum_shared[i0_5] + ) + + @T.prim_func def reducer_max(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128], dtype="float32") @@ -686,6 +842,13 @@ def test_with_block_predicate(): _check(with_block_predicate, lowered_with_block_predicate) +def test_single_reduction_loop_with_block_predicate(): + _check( + single_reduction_loop_with_block_predicate, + lowered_single_reduction_loop_with_block_predicate, + ) + + def test_reducer_max(): _check(reducer_max, lowered_reducer_max) diff --git a/tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py b/tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py new file mode 100644 index 000000000000..26e1e259ffbc --- /dev/null +++ b/tests/python/unittest/test_tir_transform_memhammer_lower_auto_copy.py @@ -0,0 +1,395 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import te +from tvm.script import tir as T +import sys +import pytest + + +@tvm.script.ir_module +class Transpose: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([16, 128], dtype="float32", scope="shared.dyn") + with T.block("A_shared"): + T.block_attr({"auto_copy": 1}) + for ax0, ax1 in T.grid(128, 16): + A_shared_dyn[ax1, ax0] = A[ax0, ax1] + with T.block("B"): + for ax1, ax0 in T.grid(16, 128): + T.block_attr({"auto_copy": 1}) + B[ax1, ax0] = A_shared_dyn[ax1, ax0] + + +@tvm.script.ir_module +class GlobalToShared: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", scope="shared.dyn") + with T.block("A_shared"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16}) + for ax0, ax1 in T.grid(128, 128): + A_shared_dyn[ax0, ax1] = A[bx * 128 + ax0, by * 128 + ax1] + with T.block("B"): + for ax0, ax1 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] + + +@tvm.script.ir_module +class SharedToGlobal: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", scope="shared.dyn") + with T.block("A_shared"): + for ax0, ax1 in T.grid(128, 128): + A_shared_dyn[ax1, ax0] = A[bx * 128 + ax0, by * 128 + ax1] + with T.block("B"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16}) + for ax1, ax0 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax1, ax0] + + +@tvm.script.ir_module +class GlobalToSharedWithLocalStage: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", scope="shared.dyn") + with T.block("A_shared"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16, "local_stage": True}) + for ax0, ax1 in T.grid(128, 128): + A_shared_dyn[ax0, ax1] = A[bx * 128 + ax0, by * 128 + ax1] + with T.block("B"): + for ax0, ax1 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] + + +@tvm.script.ir_module +class SharedToWmma: + @T.prim_func + def main() -> None: + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float16", scope="shared.dyn") + A_wmma = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") + with T.block("A_wmma"): + T.block_attr({"auto_copy": 1}) + for ax0, ax1 in T.grid(128, 128): + A_wmma[ax0, ax1] = A_shared_dyn[ax0, ax1] + + +@tvm.script.ir_module +class WmmaToShared: + @T.prim_func + def main() -> None: + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + C_accum = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") + C_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared.dyn") + with T.block("C_shared"): + T.block_attr({"auto_copy": 1}) + for ax0, ax1 in T.grid(128, 128): + C_shared[ax0, ax1] = C_accum[ax0, ax1] + + +@tvm.script.ir_module +class WmmaToGlobal: + @T.prim_func + def main(c: T.handle) -> None: + C = T.match_buffer(c, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution": True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + C_accum = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") + with T.block("C_global"): + T.block_attr({"auto_copy": 1, "vector_bytes": 16}) + for ax0, ax1 in T.grid(128, 128): + C[bx * 128 + ax0, by * 128 + ax1] = C_accum[ax0, ax1] + +@tvm.script.ir_module +class TransformedGlobalToShared: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", strides=[128, 1], scope="shared.dyn") + with T.block("A_shared"): + T.block_attr({"auto_copy":1, "vector_bytes":16}) + for outer in T.serial(16): + for ty_1 in T.thread_binding(8, thread="threadIdx.y"): + for tx in T.thread_binding(32, thread="threadIdx.x"): + for vec in T.vectorized(4): + A_shared_dyn[(((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128, (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 128] = A[bx * 128 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128, by * 128 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 128] + with T.block("B"): + for ax0, ax1 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] + +@tvm.script.ir_module +class TransformedSharedToGlobal: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", strides=[129, 1], scope="shared.dyn") + with T.block("A_shared"): + T.reads(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + T.writes(A_shared_dyn[0 : 128, 0 : 128]) + for ax0, ax1 in T.grid(128, 128): + A_shared_dyn[ax1, ax0] = A[bx * 128 + ax0, by * 128 + ax1] + with T.block("B"): + T.block_attr({"auto_copy":1, "vector_bytes":16}) + for outer in T.serial(16): + for ty_1 in T.thread_binding(8, thread="threadIdx.y"): + for tx in T.thread_binding(32, thread="threadIdx.x"): + for vec in T.vectorized(4): + B[bx * 128 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128, by * 128 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 128] = A_shared_dyn[(((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 128, (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128] + +@tvm.script.ir_module +class TransformedGlobalToSharedWithLocalStage: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, [1024, 1024]) + B = T.match_buffer(b, [1024, 1024]) + with T.block("root"): + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", strides=[128, 1], scope="shared.dyn") + with T.block("A_shared"): + T.reads(A[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + T.writes(A_shared_dyn[0 : 128, 0 : 128]) + T.block_attr({"auto_copy":1, "local_stage":True, "vector_bytes":16}) + A_local = T.alloc_buffer([16, 4], dtype="float32", scope="local") + for ty_1 in T.thread_binding(8, thread="threadIdx.y"): + for tx in T.thread_binding(32, thread="threadIdx.x"): + for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 16, 1, 1, 1): + for vec in T.vectorized(4): + A_local[ax0 * 16 + ax1 + ax2, (ax3 + ax4) * 4 + vec] = A[((bx % 8 + ax0) * 16 + ax1) * 8 + (ty_1 % 128 + ax2), ((by % 8 + ax3) * 32 + (tx % 32 + ax4)) * 4 + vec] + for serial in T.serial(16): + for vec in T.vectorized(4): + A_shared_dyn[(((serial * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128, (((serial * 8 + ty_1) * 32 + tx) * 4 + vec) % 128] = A_local[(serial * 8 + (tx * 4 + vec) // 128 + ty_1) % 128 // 8 + (((tx * 4 + vec) // 128 + ty_1) % 8 - ty_1 % 128), ((tx * 4 + vec) % 128 // 4 - tx % 32) * 4 + vec % 4] + with T.block("B"): + for ax0, ax1 in T.grid(128, 128): + B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1] + +@tvm.script.ir_module +class TransformedSharedToWmma: + @T.prim_func + def main() -> None: + s0 = T.var("int32") + s1 = T.var("int32") + # body + with T.block("root"): + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + A_shared_dyn = T.alloc_buffer([128, 128], dtype="float16", strides=[136, 1], scope="shared.dyn") + A_wmma = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") + with T.block("C_shared"): + T.reads(A_shared_dyn[0 : 128, 0 : 128]) + T.writes(A_wmma[0 : 128, 0 : 128]) + T.block_attr({"auto_copy":1}) + for ax00, ax10 in T.grid(8, 8): + with T.block("wmma_load"): + T.reads(A_shared_dyn[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16]) + T.writes(A_wmma[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16]) + src = T.match_buffer(A_shared_dyn[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16], [16, 16], dtype="float16", strides=[s1, s0], scope="shared.dyn", offset_factor=16) + tgt = T.match_buffer(A_wmma[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16], [16, 16], dtype="float16", scope="wmma.matrix_a", offset_factor=16) + T.evaluate(T.tvm_load_matrix_sync(tgt.data, 16, 16, 16, tgt.elem_offset // 256 + tgt.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float16"), src.data, src.elem_offset, s1 * 16, 1, dtype="handle"), s1, "row_major", dtype="handle")) + +@tvm.script.ir_module +class TransformedWmmaToShared: + @T.prim_func + def main() -> None: + s0 = T.var("int32") + s1 = T.var("int32") + # body + with T.block("root"): + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + C_accum = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") + C_shared = T.alloc_buffer([128, 128], dtype="float32", strides=[136, 1], scope="shared.dyn") + with T.block("A_wmma"): + T.reads(C_accum[0 : 128, 0 : 128]) + T.writes(C_shared[0 : 128, 0 : 128]) + T.block_attr({"auto_copy":1}) + for ax00, ax10 in T.grid(8, 8): + with T.block("wmma_store"): + T.reads(C_accum[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16]) + T.writes(C_shared[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16]) + src = T.match_buffer(C_accum[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16], [16, 16], dtype="float32", scope="wmma.accumulator", offset_factor=16) + tgt = T.match_buffer(C_shared[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16], [16, 16], dtype="float32", strides=[s1, s0], scope="shared.dyn", offset_factor=16) + T.evaluate(T.tvm_store_matrix_sync(src.data, 16, 16, 16, src.elem_offset // 256 + src.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float32"), tgt.data, tgt.elem_offset, s1 * 16, 2, dtype="handle"), s1, "row_major", dtype="handle")) + +@tvm.script.ir_module +class TransformedWmmaToGlobal: + @T.prim_func + def main(c: T.handle) -> None: + C = T.match_buffer(c, [1024, 1024]) + s0 = T.var("int32") + s1 = T.var("int32") + # body + with T.block("root"): + T.block_attr({"warp_execution":True}) + for bx in T.thread_binding(8, thread="blockIdx.x"): + for by in T.thread_binding(8, thread="blockIdx.y"): + for ty in T.thread_binding(8, thread="threadIdx.y"): + with T.block(): + C_accum = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") + with T.block("C_global"): + T.reads(C_accum[0 : 128, 0 : 128]) + T.writes(C[bx * 128 : bx * 128 + 128, by * 128 : by * 128 + 128]) + T.block_attr({"auto_copy":1, "vector_bytes":16}) + C_shared_dyn = T.alloc_buffer([16, 128], dtype="float32", strides=[136, 1], scope="shared.dyn") + for ax00 in T.serial(8): + for ax10 in T.serial(8): + with T.block("wmma_store"): + T.reads(C_accum[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16]) + T.writes(C_shared_dyn[((ax00 // 8 + bx) % 8 - bx % 8 + (ax00 % 8 - ax00 % 64)) * 16 : ((ax00 // 8 + bx) % 8 - bx % 8 + (ax00 % 8 - ax00 % 64)) * 16 + 16, (((ax10 // 8 + by) % 8 - by % 8) * 8 + ax10 % 8) * 16 : (((ax10 // 8 + by) % 8 - by % 8) * 8 + ax10 % 8) * 16 + 16]) + src = T.match_buffer(C_accum[ax00 * 16 : ax00 * 16 + 16, ax10 * 16 : ax10 * 16 + 16], [16, 16], dtype="float32", scope="wmma.accumulator", offset_factor=16) + tgt = T.match_buffer(C_shared_dyn[((ax00 // 8 + bx) % 8 - bx % 8 + (ax00 % 8 - ax00 % 64)) * 16 : ((ax00 // 8 + bx) % 8 - bx % 8 + (ax00 % 8 - ax00 % 64)) * 16 + 16, (((ax10 // 8 + by) % 8 - by % 8) * 8 + ax10 % 8) * 16 : (((ax10 // 8 + by) % 8 - by % 8) * 8 + ax10 % 8) * 16 + 16], [16, 16], dtype="float32", strides=[s1, s0], scope="shared.dyn", offset_factor=16) + T.evaluate(T.tvm_store_matrix_sync(src.data, 16, 16, 16, src.elem_offset // 256 + src.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float32"), tgt.data, tgt.elem_offset, s1 * 16, 2, dtype="handle"), s1, "row_major", dtype="handle")) + for outer in T.serial(2): + for ty_1 in T.thread_binding(8, thread="threadIdx.y"): + for tx in T.thread_binding(32, thread="threadIdx.x"): + for vec in T.vectorized(4): + C[((bx % 8 + 0) * 8 + (ax00 % 64 + 0)) * 16 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 16 // 8 % 16, ((by % 8 + 0) * 8 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 16 % 8) * 16 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 16] = C_shared_dyn[(0 + 0) * 16 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 16 // 8 % 16, (0 * 8 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 16 % 8) * 16 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 16] + + +def _check(original, transformed): + mod = tvm.tir.transform.LowerAutoCopy()(original) + tvm.ir.assert_structural_equal(mod, transformed, True) + + +def test_coalesce_vectorize(): + _check(GlobalToShared, TransformedGlobalToShared) + + +def test_inverse(): + _check(SharedToGlobal, TransformedSharedToGlobal) + + +def test_local_stage(): + _check(GlobalToSharedWithLocalStage, TransformedGlobalToSharedWithLocalStage) + + +def test_rewrite_shared_to_wmma(): + _check(SharedToWmma, TransformedSharedToWmma) + + +def test_rewrite_wmma_to_shared(): + _check(WmmaToShared, TransformedWmmaToShared) + + +def test_rewrite_wmma_to_global(): + _check(WmmaToGlobal, TransformedWmmaToGlobal) + + +def verify_single_allocation(stmt, alloc_size=None): + num_alloc = [0] + alloc_extents = [] + + def verify(n): + if ( + isinstance(n, tvm.tir.Allocate) + and n.buffer_var.type_annotation.storage_scope == "shared.dyn" + ): + num_alloc[0] += 1 + alloc_extents.append(n.extents[0]) + + tvm.tir.stmt_functor.post_order_visit(stmt, verify) + assert num_alloc[0] == 1 + + if alloc_size: + assert alloc_extents[0] == alloc_size + + +def test_auto_padding(): + mod = tvm.tir.transform.LowerAutoCopy()(Transpose) + mod = tvm.tir.transform.FlattenBuffer()(mod) + verify_single_allocation(mod['main'].body, 16 * 130) + + +if __name__ == "__main__": + test_coalesce_vectorize() + test_inverse() + test_local_stage() + test_rewrite_shared_to_wmma() + test_rewrite_wmma_to_shared() + test_rewrite_wmma_to_global() + test_auto_padding() diff --git a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py new file mode 100644 index 000000000000..6217d2f0989a --- /dev/null +++ b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm.script import tir as T + + +@tvm.script.ir_module +class Before: + @T.prim_func + def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + blockIdx_x = T.env_thread("blockIdx.x") + # body + T.launch_thread(blockIdx_x, 64) + conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local") + PadInput_shared = T.allocate([768], "float32", "shared") + weight_shared = T.allocate([4096], "float32", "shared") + T.launch_thread(threadIdx_x, 32) + for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): + T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True) + for i6_0 in T.serial(16): + for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): + T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 and blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 < 5, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True) + for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): + T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) % 256 // 8 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4)) + for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): + T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True) + for ax1, ax2 in T.grid(2, 4): + T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True) + + +@tvm.script.ir_module +class After: + @T.prim_func + def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + blockIdx_x = T.env_thread("blockIdx.x") + # body + T.launch_thread(blockIdx_x, 64) + conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local") + PadInput_shared = T.allocate([768], "float32", "shared") + weight_shared = T.allocate([4096], "float32", "shared") + T.launch_thread(threadIdx_x, 32) + for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): + T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True) + for i6_0 in T.serial(16): + for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): + T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True) + for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): + T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x // 2) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4)) + for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): + T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True) + for ax1, ax2 in T.grid(2, 4): + T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True) + + +def tesd_renormalize_split_pattern(): + after = tvm.tir.transform.RenomalizeSplitPattern()(Before) + tvm.ir.assert_structural_equal(after, After) + + +if __name__ == "__main__": + tesd_renormalize_split_pattern()