diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 1ac3f80ecf39..5e223c98d74d 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -219,6 +219,19 @@ class ScheduleNode : public runtime::Object {
    * \return The new loops after split
    */
   virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) = 0;
+  /*!
+   * \brief Reorder a list of loops. It doesn't require the loops to be consecutive.
+   * It requires:
+   * 1) The loops are in the same chain. That means: the loops can be ordered as [l_1, l_2, ...,
+   * l_n] where l_i is an ancestor of l_{i+1}, and there are only single-branch loops between
+   * l_1 and l_n (which also indicates they are under the same scope).
+   * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops.
+   * 3) For every block under the loop nests, its block binding must be affine, and the block
+   * variables must be either data parallel or reduction.
+   * 4) No duplicated loops are allowed in the arguments.
+   * \param ordered_loop_rvs The loops in the new order
+   */
+  virtual void Reorder(const Array<LoopRV>& ordered_loop_rvs) = 0;
   /******** Schedule: Manipulate ForKind ********/
   /*!
    * \brief Parallelize the input loop. It requires:
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 46e5fd6fddcb..c9cbf45b9055 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -442,6 +442,65 @@ def after_split(a: ty.handle, b: ty.handle) -> None:
         # that there is at most one None in `factors`
         return _ffi_api.ScheduleSplit(self, loop, factors)  # type: ignore # pylint: disable=no-member
 
+    def reorder(self, *ordered_loops: LoopRV) -> None:
+        """
+        Reorder a list of loops. It doesn't require the loops to be consecutive.
+        It requires:
+        1) The loops are in the same chain. That means: the loops can be ordered as
+        [l_1, l_2, ..., l_n] where l_i is an ancestor of l_{i+1}, and there are only
+        single-branch loops between l_1 and l_n (which also indicates they are under
+        the same scope).
+        2) After reordering, the domain of an outer loop cannot depend on any of the
+        inner loops.
+        3) For every block under the loop nests, its block binding must be affine,
+        and the block variables must be either data parallel or reduction.
+        4) No duplicated loops are allowed in the arguments.
+
+        Parameters
+        ----------
+        *ordered_loops : LoopRV
+            The loops in the new order
+
+        Examples
+        --------
+
+        Before reorder, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_reorder(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                for i, j in tir.grid(128, 128):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and do reorder:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_reorder)
+            i, j = sch.get_loops(sch.get_block("B"))
+            sch.reorder(j, i)
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying reorder, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_reorder(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                # Here j and i are reordered
+                for j, i in tir.grid(128, 128):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        tir.bind(vi, i)
+                        tir.bind(vj, j)
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        """
+        _ffi_api.ScheduleReorder(self, ordered_loops)  # type: ignore # pylint: disable=no-member
+
     ########## Schedule: Manipulate ForKind ##########
 
     def parallel(self, loop: LoopRV) -> None:
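The docstring above only demonstrates a two-loop swap. Since requirement 1 allows non-consecutive loops, here is a minimal illustrative sketch in the same TVMScript dialect (the `before_reorder3d` workload is invented for this example and is not part of the patch): reordering `k` and `i` leaves the middle loop `j` where it is.

```python
import tvm
from tvm import tir
from tvm.script import ty


@tvm.script.tir
def before_reorder3d(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j, k in tir.grid(128, 128, 128):
        with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0


sch = tir.Schedule(before_reorder3d)
i, j, k = sch.get_loops(sch.get_block("B"))
sch.reorder(k, i)  # non-consecutive: only k and i swap positions, j stays in the middle
print(tvm.script.asscript(sch.mod["main"]))  # loop order should now be k, j, i
```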
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index b18090dd7215..084d0b0eec6a 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -346,6 +346,13 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
   return CreateRV<LoopRV>(results);
 }
 
+void ConcreteScheduleNode::Reorder(const Array<LoopRV>& ordered_loop_rvs) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  tir::Reorder(state_, GetSRefs(ordered_loop_rvs));
+  TVM_TIR_SCHEDULE_END("reorder", this->error_render_level_);
+  this->state_->DebugVerify();
+}
+
 /******** Schedule: Manipulate ForKind ********/
 
 void ConcreteScheduleNode::Parallel(const LoopRV& loop_rv) {
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 2af4675ddcca..97819d63edb6 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -81,6 +81,7 @@ class ConcreteScheduleNode : public ScheduleNode {
   /******** Schedule: Transform loops ********/
   LoopRV Fuse(const Array<LoopRV>& loop_rvs) override;
   Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) override;
+  void Reorder(const Array<LoopRV>& ordered_loop_rvs) override;
   /******** Schedule: Manipulate ForKind ********/
   void Parallel(const LoopRV& loop_rv) override;
   void Vectorize(const LoopRV& loop_rv) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 04c38f67da7d..2cf59f0b27c0 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -63,6 +63,21 @@ TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
  * \return The sref to the fused loop
  */
 TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs);
+/*!
+ * \brief Reorder a list of loops. It doesn't require the loops to be consecutive.
+ * It requires:
+ * 1) The loops are in the same chain. That means: the loops can be ordered as [l_1, l_2, ...,
+ * l_n] where l_i is an ancestor of l_{i+1}, and there are only single-branch loops between
+ * l_1 and l_n (which also indicates they are under the same scope).
+ * 2) After reordering, the domain of an outer loop cannot depend on any of the inner loops.
+ * 3) For every block under the loop nests, its block binding must be affine, and the block
+ * variables must be either data parallel or reduction.
+ * 4) No duplicated loops are allowed in the arguments.
+ * \param self The state of the schedule
+ * \param ordered_loop_srefs An array of srefs which indicates the new order of loops
+ */
+TVM_DLL void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs);
+
 /******** Schedule: Manipulate ForKind ********/
 /*!
  * \brief Parallelize the input loop. It requires:
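To make requirement 2 above concrete: if an inner loop's extent mentions an outer loop variable, swapping the two would leave the new outer extent depending on the new inner variable, so `Reorder` must reject it. A hedged sketch, mirroring the `elementwise_dependent_loop` fixture added later in this patch (the 2-D `triangular` workload here is invented for illustration):

```python
import pytest
import tvm
from tvm import tir
from tvm.script import ty


@tvm.script.tir
def triangular(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, 128))
    for i in tir.serial(0, 128):
        for j in tir.serial(0, i):  # extent of j depends on i
            with tir.block([128, i], "B") as [vi, vj]:
                B[vi, vj] = A[vi, vj] * 2.0


sch = tir.Schedule(triangular)
i, j = sch.get_loops(sch.get_block("B"))
with pytest.raises(tvm.tir.ScheduleError):
    sch.reorder(j, i)  # would make j's extent depend on the now-inner i
```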
diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc
index d1875df61ac7..7c2b61344427 100644
--- a/src/tir/schedule/primitive/loop_transformation.cc
+++ b/src/tir/schedule/primitive/loop_transformation.cc
@@ -131,6 +131,55 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
   Map<Var, Range> loop_var2extent_;
 };
 
+class BlockPropertyError : public ScheduleError {
+ public:
+  /*!
+   * \brief Check that all the blocks under the specific stmt have affine bindings and only have
+   * data-parallel or reduction block iters
+   * \param self The state of the schedule
+   * \param sref The sref to the specific stmt
+   */
+  static void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self,
+                                                 const StmtSRefNode* sref) {
+    class BlockIterTypeAndAffineBindingChecker : public StmtVisitor {
+     public:
+      explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state) : state_(state) {}
+
+     private:
+      void VisitStmt_(const BlockNode* op) final {
+        for (const IterVar& iter_var : op->iter_vars) {
+          if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) {
+            throw BlockPropertyError(state_->mod, GetRef<Block>(op));
+          }
+          CheckAffineBinding(state_, GetRef<Block>(op));
+        }
+      }
+      const ScheduleState& state_;
+    };
+
+    BlockIterTypeAndAffineBindingChecker checker(self);
+    checker(GetRef<Stmt>(sref->stmt));
+  }
+
+  explicit BlockPropertyError(IRModule mod, Block block) : mod_(mod), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The block under the loops to be reordered has a block iter type "
+           "other than data-parallel or reduction";
+  }
+
+  String DetailRenderTemplate() const final {
+    return "The block {0} under the loops to be reordered has a block iter type other than "
+           "data-parallel or reduction";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+  IRModule mod_;
+  Block block_;
+};
+
 class HasAnnotationOrThreadBindingError : public ScheduleError {
  public:
   explicit HasAnnotationOrThreadBindingError(IRModule mod, For loop)
@@ -253,6 +302,83 @@ class WrongFactorProductError : public ScheduleError {
   For loop_;
 };
 
+class LoopMultiAppearanceError : public ScheduleError {
+ public:
+  explicit LoopMultiAppearanceError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: Some loop appears in the input array multiple times.";
+  }
+
+  String DetailRenderTemplate() const final {
+    return "Loop {0} appears in the input array multiple times.";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
+
+  IRModule mod_;
+  For loop_;
+};
+
+class LoopsNotAChainError : public ScheduleError {
+ public:
+  enum class ProblemKind { kNotUnderAScope, kHaveNonSingleBranchStmt };
+
+  explicit LoopsNotAChainError(IRModule mod, Optional<Stmt> problematic_loop, ProblemKind kind)
+      : mod_(mod), problematic_loop_(std::move(problematic_loop)), kind_(kind) {}
+
+  String FastErrorString() const final { return "ScheduleError: the loops are not in a chain"; }
+
+  String DetailRenderTemplate() const final {
+    std::stringstream ss;
+    ss << "The loops are not in a chain because";
+    if (kind_ == ProblemKind::kNotUnderAScope) {
+      ss << " they are not under the same scope.";
+    } else {
+      ss << " there is a non-single-branch stmt in between. Problematic stmt: {0}";
+    }
+    return ss.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final {
+    if (kind_ == ProblemKind::kNotUnderAScope) {
+      return {};
+    } else {
+      ICHECK(problematic_loop_.defined());
+      return {problematic_loop_.value()};
+    }
+  }
+
+  IRModule mod_;
+  Optional<Stmt> problematic_loop_;
+  ProblemKind kind_;
+};
+
+class DependentLoopError : public ScheduleError {
+ public:
+  explicit DependentLoopError(IRModule mod, For loop, String inner_var)
+      : mod_(mod), loop_(std::move(loop)), inner_var_(std::move(inner_var)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: An outer loop's `min` or `extent` is dependent on an inner loop "
+           "in the new order";
+  }
+
+  String DetailRenderTemplate() const final {
+    return "Outer loop {0}'s `min` or `extent` is dependent on an inner loop " + inner_var_ +
+           " in the new order";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
+
+  IRModule mod_;
+  For loop_;
+  String inner_var_;
+};
+
 Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
                       const Array<PrimExpr>& factors) {
   // Invariance
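All of the classes above derive from `ScheduleError` and surface in Python as `tvm.tir.ScheduleError`. A quick sketch of what a user sees when `LoopMultiAppearanceError` fires (reusing `before_reorder` from the docstring example earlier in this patch; the printed text follows `DetailRenderTemplate`, and the exact rendering depends on the schedule's `error_render_level`):

```python
import tvm
from tvm import tir

# assumes before_reorder from the docstring example above is in scope
sch = tir.Schedule(before_reorder)
i, j = sch.get_loops(sch.get_block("B"))
try:
    sch.reorder(i, j, i)  # the loop i appears twice
except tvm.tir.ScheduleError as err:
    print(err)  # e.g. "Loop <...> appears in the input array multiple times."
```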
@@ -384,6 +510,182 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
   self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
   return self->stmt2ref.at(new_stmt.get());
 }
+
+/*!
+ * \brief Collect an array of loop srefs into a set
+ * \param self The schedule state
+ * \param ordered_loop_srefs The array of loop srefs
+ * \return A set containing all loops in the array
+ * \throws ScheduleError If there are duplicate loops in the array
+ */
+std::unordered_set<const StmtSRefNode*> CollectLoopsIntoSet(
+    const ScheduleState& self, const Array<StmtSRef>& ordered_loop_srefs) {
+  std::unordered_set<const StmtSRefNode*> loop_srefs;
+  loop_srefs.reserve(ordered_loop_srefs.size());
+  for (const StmtSRef& loop_sref : ordered_loop_srefs) {
+    auto inserted = loop_srefs.insert(loop_sref.get());
+    if (!inserted.second) {
+      const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+      throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop));
+    }
+  }
+  return loop_srefs;
+}
+
+/*!
+ * \brief Get the top and bottom boundary of the reorder range (which should be a chain)
+ * \param self The schedule state
+ * \param loop_srefs The set containing the srefs to the loops to be reordered
+ * \return A pair containing the top and bottom boundary of the reorder range
+ * \throws ScheduleError If the loops to be reordered are not in a chain
+ */
+std::pair<const StmtSRefNode*, const StmtSRefNode*> GetBoundaryOfReorderRange(
+    const ScheduleState& self, const std::unordered_set<const StmtSRefNode*>& loop_srefs) {
+  const StmtSRefNode* top = nullptr;
+  const StmtSRefNode* bottom = *loop_srefs.begin();
+  std::unordered_set<const StmtSRefNode*> visited;
+  bool scope_block_visited = false;
+  bool first_traversal = true;
+  for (const StmtSRefNode* loop_sref : loop_srefs) {
+    if (visited.count(loop_sref)) {
+      continue;
+    }
+    for (const StmtSRefNode* v = loop_sref;; v = v->parent) {
+      // Case 1. If `v` corresponds to a block, stop traversal.
+      if (v->stmt->IsInstance<BlockNode>()) {
+        if (scope_block_visited) {
+          throw LoopsNotAChainError(self->mod, NullOpt,
+                                    LoopsNotAChainError::ProblemKind::kNotUnderAScope);
+        }
+        scope_block_visited = true;
+        break;
+      }
+      // Case 2. If `v` corresponds to a previously-visited loop, stop traversal and update
+      // `bottom`.
+      if (visited.count(v)) {
+        if (v != bottom) {
+          throw LoopsNotAChainError(self->mod, GetRef<Stmt>(v->stmt),
                                    LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
+        }
+        bottom = loop_sref;
+        break;
+      }
+      // Case 3. Add `v` into `visited`.
+      visited.insert(v);
+      // If it's the first traversal and the loop corresponding to `v` is in the input array,
+      // update `top`.
+      if (first_traversal && loop_srefs.count(v)) {
+        top = v;
+      }
+    }
+    first_traversal = false;
+  }
+  return std::make_pair(top, bottom);
+}
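The traversal in `GetBoundaryOfReorderRange` is dense, so here is a toy Python model of the same logic (the `Node` class and the tree below are invented stand-ins, not TVM code): every not-yet-visited input loop walks up its parent chain; each walk must end either at the single scope block (only once) or exactly at the current `bottom`, and `top` ends up as the highest input loop seen on the first walk.

```python
class Node:
    """Toy stand-in for StmtSRefNode: a loop or block with a parent pointer."""

    def __init__(self, name, parent=None, is_block=False):
        self.name = name
        self.parent = parent
        self.is_block = is_block

    def __repr__(self):
        return self.name


def boundary_of_reorder_range(loops):
    """Toy mirror of the C++ traversal; returns (top, bottom) or raises ValueError."""
    loop_set = set(loops)
    visited = set()
    top = None
    bottom = loops[0]
    scope_block_visited = False
    first_traversal = True
    for loop in loops:
        if loop in visited:
            continue
        v = loop
        while True:
            if v.is_block:  # Case 1: reached the scope block
                if scope_block_visited:
                    raise ValueError("loops are not under the same scope")
                scope_block_visited = True
                break
            if v in visited:  # Case 2: joined a previous walk
                if v is not bottom:
                    raise ValueError("non-single-branch stmt in between")
                bottom = loop
                break
            visited.add(v)  # Case 3: keep climbing
            if first_traversal and v in loop_set:
                top = v
            v = v.parent
        first_traversal = False
    return top, bottom


# Toy sref chain: block -> i -> j -> k (k innermost)
block = Node("block", is_block=True)
i = Node("i", block)
j = Node("j", i)
k = Node("k", j)
print(boundary_of_reorder_range([k, i]))  # (i, k): top is i, bottom is k
```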
+
+/*!
+ * \brief Get all the loops in the reorder range
+ * \param self The schedule state
+ * \param top The top boundary of the reorder range
+ * \param bottom The bottom boundary of the reorder range
+ * \return An array containing all the loops in the reorder range
+ * \throws ScheduleError If some loop in the reorder range is not single-branch
+ */
+std::vector<const StmtSRefNode*> GetLoopsInReorderRange(const ScheduleState& self,
+                                                        const StmtSRefNode* top,
+                                                        const StmtSRefNode* bottom) {
+  std::vector<const StmtSRefNode*> chain;
+  for (const StmtSRefNode* loop_sref = bottom; loop_sref != top;) {
+    const StmtSRefNode* parent_loop_sref = loop_sref->parent;
+    const ForNode* outer = parent_loop_sref->StmtAs<ForNode>();
+    const ForNode* inner = loop_sref->StmtAs<ForNode>();
+    ICHECK(outer != nullptr && inner != nullptr);
+    if (outer->body.get() != inner) {
+      throw LoopsNotAChainError(self->mod, GetRef<Stmt>(outer),
+                                LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
+    }
+    chain.push_back(loop_sref);
+    loop_sref = parent_loop_sref;
+  }
+  chain.push_back(top);
+  return chain;
+}
+
+/*!
+ * \brief Construct a loop chain in the new order
+ * \param self The schedule state
+ * \param chain The loops in the reorder range
+ * \param ordered_loop_srefs The loop srefs to be reordered
+ * \param loop_srefs The set containing the loop srefs to be reordered
+ * \return The new loop chain
+ * \throws ScheduleError If the domain of an outer loop depends on any of the inner loops after
+ * reordering
+ */
+For ConstructNewLoopChain(const ScheduleState& self, std::vector<const StmtSRefNode*> chain,
+                          const Array<StmtSRef>& ordered_loop_srefs,
+                          const std::unordered_set<const StmtSRefNode*>& loop_srefs) {
+  std::unordered_set<const VarNode*> inner_vars;
+  inner_vars.reserve(chain.size());
+  For new_loop{nullptr};
+  int index = static_cast<int>(ordered_loop_srefs.size()) - 1;
+  for (const StmtSRefNode* loop_sref : chain) {
+    const ForNode* copy = nullptr;
+    if (loop_srefs.count(loop_sref)) {
+      copy = ordered_loop_srefs[index]->StmtAs<ForNode>();
+      --index;
+    } else {
+      copy = loop_sref->StmtAs<ForNode>();
+    }
+    ICHECK(copy != nullptr);
+    ObjectPtr<ForNode> n = make_object<ForNode>(*copy);
+    if (new_loop.defined()) {
+      n->body = new_loop;
+    } else {
+      n->body = loop_sref->StmtAs<ForNode>()->body;
+    }
+    const VarNode* used_var = nullptr;
+    auto f_contain = [&inner_vars, &used_var](const VarNode* var) {
+      if (inner_vars.count(var)) {
+        used_var = var;
+        return true;
+      }
+      return false;
+    };
+    if (UsesVar(copy->min, f_contain) || UsesVar(copy->extent, f_contain)) {
+      throw DependentLoopError(self->mod, GetRef<For>(copy), used_var->name_hint);
+    }
+    inner_vars.insert(copy->loop_var.get());
+    new_loop = For(std::move(n));
+  }
+  return new_loop;
+}
+
+void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
+  if (ordered_loop_srefs.size() <= 1) {
+    return;
+  }
+  // Step 1. Check uniqueness and collect the input loop srefs into a set.
+  std::unordered_set<const StmtSRefNode*> loop_srefs =
+      CollectLoopsIntoSet(self, ordered_loop_srefs);
+  // Step 2. Gather the loops to be reordered.
+  // For each loop sref in the input sref array, traverse upwards along its parent pointer in the
+  // sref tree, and stop on either a block or a previously-visited loop:
+  // - the top of the reorder range is the last loop visited in the first traversal which exists
+  //   in the input array;
+  // - the bottom of the reorder range is the last loop in the input array which is not visited
+  //   in the previous traversals.
+  const StmtSRefNode* top = nullptr;
+  const StmtSRefNode* bottom = nullptr;
+  std::tie(top, bottom) = GetBoundaryOfReorderRange(self, loop_srefs);
+  // Step 3. Collect all loops in the chain and check that the loops are single-branch.
+  std::vector<const StmtSRefNode*> chain = GetLoopsInReorderRange(self, top, bottom);
+  // Step 4. Check that every block below has only data-parallel or reduction block vars,
+  // and that its block binding is affine.
+  BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, bottom);
+  // Step 5. Replace the original loops with the reordered loops, checking that no outer loop
+  // depends on an inner loop.
+  For new_loop = ConstructNewLoopChain(self, std::move(chain), ordered_loop_srefs, loop_srefs);
+  self->Replace(GetRef<StmtSRef>(top), new_loop, {});
+}
 
 /******** Instruction Registration ********/
 
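`ConstructNewLoopChain` visits the chain innermost-first and consumes `ordered_loop_srefs` back-to-front, so the positions occupied by input loops are refilled in the user-given order while untouched loops keep their places. A toy model over plain strings (invented for illustration; the real function also rebuilds loop bodies and performs the dependency check that raises `DependentLoopError`):

```python
def construct_new_order(chain_bottom_up, ordered):
    """Toy mirror of ConstructNewLoopChain's reordering step.

    chain_bottom_up: loop names from innermost to outermost.
    ordered: the user-given new order, outermost first.
    Returns the full loop order, outermost first.
    """
    in_input = set(ordered)
    index = len(ordered) - 1  # consume the user order back-to-front
    new_chain = []
    for name in chain_bottom_up:
        if name in in_input:
            new_chain.append(ordered[index])  # next user loop, innermost position first
            index -= 1
        else:
            new_chain.append(name)  # untouched loops keep their positions
    return list(reversed(new_chain))


# Original nest i -> j -> k -> l; reorder(l, i) swaps the positions of i and l.
print(construct_new_order(["l", "k", "j", "i"], ["l", "i"]))  # ['l', 'j', 'k', 'i']
```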
@@ -456,8 +758,40 @@ struct FuseTraits : public UnpackedInstTraits<FuseTraits> {
   friend struct ::tvm::tir::UnpackedInstTraits;
 };
 
+struct ReorderTraits : public UnpackedInstTraits<ReorderTraits> {
+  static constexpr const char* kName = "Reorder";
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumDecisions = 0;
+
+  template <size_t delta>
+  static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter,
+                                           const Array<ObjectRef>& inputs) {
+    setter(delta, inputs);
+  }
+
+  static void UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs) {
+    return sch->Reorder(loop_rvs);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs) {
+    PythonAPICall py("reorder");
+    for (const String& loop_rv : loop_rvs) {
+      py.Input("", loop_rv);
+    }
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
 TVM_REGISTER_INST_KIND_TRAITS(SplitTraits);
 TVM_REGISTER_INST_KIND_TRAITS(FuseTraits);
+TVM_REGISTER_INST_KIND_TRAITS(ReorderTraits);
 
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index f21a4c370a5b..29681fdf0926 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -125,6 +125,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops")
 /******** (FFI) Transform loops ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&ScheduleNode::Fuse);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder")
+    .set_body_method<Schedule>(&ScheduleNode::Reorder);
 /******** (FFI) Manipulate ForKind ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleParallel")
     .set_body_method<Schedule>(&ScheduleNode::Parallel);
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index e3f675e8628f..ae6a194b9888 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -99,6 +99,16 @@ Array<LoopRV> TracedScheduleNode::Split(const LoopRV& loop_rv,
   return results;
 }
 
+void TracedScheduleNode::Reorder(const Array<LoopRV>& ordered_loop_rvs) {
+  ConcreteScheduleNode::Reorder(ordered_loop_rvs);
+
+  static const InstructionKind& kind = InstructionKind::Get("Reorder");
+  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+                                      /*inputs=*/{ordered_loop_rvs.begin(), ordered_loop_rvs.end()},
+                                      /*attrs=*/{},
+                                      /*outputs=*/{}));
+}
+
 /******** Schedule: Manipulate ForKind ********/
 
 void TracedScheduleNode::Parallel(const LoopRV& loop_rv) {
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index f5f31abe1556..11128ba32fad 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -54,6 +54,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   /******** Schedule: Transform loops ********/
   LoopRV Fuse(const Array<LoopRV>& loop_rvs) final;
   Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs) final;
+  void Reorder(const Array<LoopRV>& ordered_loop_rvs) final;
   /******** Schedule: Manipulate ForKind ********/
   void Parallel(const LoopRV& loop_rv) final;
   void Vectorize(const LoopRV& loop_rv) final;
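Because `ReorderTraits` registers the instruction and `TracedScheduleNode::Reorder` appends it to the trace, a reorder call is recorded and can be re-applied, which is what `verify_trace_roundtrip` in the tests below checks. A small sketch (reusing `before_reorder` from the docstring example; the exact trace text is indicative, not guaranteed):

```python
import tvm
from tvm import tir

# assumes before_reorder from the docstring example above is in scope
sch = tir.Schedule(before_reorder)
i, j = sch.get_loops(sch.get_block("B"))
sch.reorder(j, i)
print(sch.trace)  # expected to contain a reorder(...) instruction over the two loop RVs
```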
diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py
new file mode 100644
index 000000000000..091a77df2030
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_reorder.py
@@ -0,0 +1,294 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+
+import pytest
+import tvm
+from tvm import tir
+from tvm.script import ty
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+
+@tvm.script.tir
+def elementwise(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128, 128, 128))
+    B = tir.match_buffer(b, (128, 128, 128, 128))
+    with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
+        B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
+
+
+@tvm.script.tir
+def elementwise_not_affine(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128, 128, 128))
+    B = tir.match_buffer(b, (128, 128, 128, 128))
+    for i, j, k, l in tir.grid(128, 128, 128, 8):
+        with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
+            tir.bind(vi, i)
+            tir.bind(vj, j)
+            tir.bind(vk, k)
+            tir.bind(vl, l * 16)
+            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
+
+
+@tvm.script.tir
+def elementwise_dependent_loop(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128, 128, 128))
+    B = tir.match_buffer(b, (128, 128, 128, 128))
+    for i in tir.serial(0, 128):
+        for j, k, l in tir.grid(128, i, 128):
+            with tir.block([128, 128, i, 128], "B") as [vi, vj, vk, vl]:
+                B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
+
+
+@tvm.script.tir
+def elementwise_predicate(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128, 128, 128))
+    B = tir.match_buffer(b, (128, 128, 128, 128))
+    for i, j, k, l in tir.grid(128, 128, 128, 128):
+        with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
+            tir.where(i * 2097152 + j * 16384 + k * 128 + l < 100)
+            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
+
+
+@tvm.script.tir
+def elementwise_non_single_branch(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128, 128))
+    C = tir.alloc_buffer((128, 128, 128))
+    B = tir.match_buffer(b, (128, 128, 128))
+    for i, j in tir.grid(128, 128):
+        for k in tir.serial(0, 128):
+            with tir.block([128, 128, 128], "C") as [vi, vj, vk]:
+                tir.bind(vi, i)
+                tir.bind(vj, j)
+                tir.bind(vk, k)
+                C[vi, vj, vk] = A[vi, vj, vk] * 2.0
+        for k in tir.serial(0, 128):
+            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
+                tir.bind(vi, i)
+                tir.bind(vj, j)
+                tir.bind(vk, k)
+                B[vi, vj, vk] = C[vi, vj, vk] * 2.0
+
+
+@tvm.script.tir
+def elementwise_with_loops_not_same_scope(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128, 128))
+    B = tir.match_buffer(b, (128, 128, 128))
+    for i, j in tir.grid(128, 128):
+        with tir.block([128, 128], "A") as [vi, vj]:
+            tir.bind(vi, i)
+            tir.bind(vj, j)
+            for k in tir.serial(0, 128):
+                with tir.block([128], "B") as [vk]:
+                    tir.bind(vk, k)
+                    tir.reads([A[vi, vj, vk]])
+                    tir.writes([B[vi, vj, vk]])
+                    B[vi, vj, vk] = A[vi, vj, vk] * 2.0
+
+
+@tvm.script.tir
+def elementwise_with_wrong_block_var_type(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128, 128))
+    B = tir.match_buffer(b, (128, 128, 128))
+    for i, j, k in tir.grid(128, 128, 128):
+        with tir.block([128, 128, tir.scan_axis(0, 128)], "B") as [vi, vj, vk]:
+            tir.bind(vi, i)
+            tir.bind(vj, j)
+            tir.bind(vk, k)
+            tir.reads([A[vi, vj, vk]])
+            tir.writes([B[vi, vj, vk]])
+            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
+
+
+@tvm.script.tir
+def elementwise_reordered(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128, 128, 128))
+    B = tir.match_buffer(b, (128, 128, 128, 128))
+    for l, j, k, i in tir.grid(128, 128, 128, 128):
+        with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
+            tir.bind(vi, i)
+            tir.bind(vj, j)
+            tir.bind(vk, k)
+            tir.bind(vl, l)
+            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
+
+
+@tvm.script.tir
+def elementwise_reordered2(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128, 128, 128))
+    B = tir.match_buffer(b, (128, 128, 128, 128))
+    for k, j, i, l in tir.grid(128, 128, 128, 128):
+        with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
+            tir.bind(vi, i)
+            tir.bind(vj, j)
+            tir.bind(vk, k)
+            tir.bind(vl, l)
+            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
+
+
+@tvm.script.tir
+def elementwise_reordered_with_predicate(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 128, 128, 128))
+    B = tir.match_buffer(b, (128, 128, 128, 128))
+    for l, j, k, i in tir.grid(128, 128, 128, 128):
+        with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
+            tir.where(i * 2097152 + j * 16384 + k * 128 + l < 100)
+            tir.bind(vi, i)
+            tir.bind(vj, j)
+            tir.bind(vk, k)
+            tir.bind(vl, l)
+            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
+
+
+@tvm.script.tir
+def opaque_access(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, [16, 16], "float32")
+    B = tir.match_buffer(b, [16, 16], "float32")
+    with tir.block([16, 16], "A") as [vi, vj]:
+        tir.reads([])
+        tir.writes([A[0:16, 0:16]])
+        tir.store(A.data, vi * 16 + vj, 1)
+    with tir.block([16, 16], "B") as [vi, vj]:
+        tir.reads([])
+        tir.writes([B[0:16, 0:16]])
+        tir.evaluate(tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle"))
+
+
+@tvm.script.tir
+def opaque_access_reorder(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, [16, 16], "float32")
+    B = tir.match_buffer(b, [16, 16], "float32")
+    for j, i in tir.grid(16, 16):
+        with tir.block([16, 16], "A") as [vi, vj]:
+            tir.bind(vi, i)
+            tir.bind(vj, j)
+            tir.reads([])
+            tir.writes([A[0:16, 0:16]])
+            tir.store(A.data, vi * 16 + vj, 1)
+    for j, i in tir.grid(16, 16):
+        with tir.block([16, 16], "B") as [vi, vj]:
+            tir.bind(vi, i)
+            tir.bind(vj, j)
+            tir.reads([])
+            tir.writes([B[0:16, 0:16]])
+            tir.evaluate(tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle"))
+
+
+# pylint: enable=no-member,invalid-name,unused-variable
+
+
+def test_reorder():
+    sch = tir.Schedule(elementwise, debug_mask="all")
+    block_b = sch.get_block("B")
+    i, j, k, l = sch.get_loops(block_b)
+    sch.reorder(l, i)
+    tvm.ir.assert_structural_equal(elementwise_reordered, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=elementwise)
+
+
+def test_reorder2():
+    sch = tir.Schedule(elementwise, debug_mask="all")
+    block_b = sch.get_block("B")
+    i, j, k, l = sch.get_loops(block_b)
+    sch.reorder(k, i, l)
+    tvm.ir.assert_structural_equal(elementwise_reordered2, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=elementwise)
+
+
+def test_reorder_with_opaque_access():
+    sch = tir.Schedule(opaque_access, debug_mask="all")
+    block_a = sch.get_block("A")
+    i, j = sch.get_loops(block_a)
+    sch.reorder(j, i)
+    block_b = sch.get_block("B")
+    i, j = sch.get_loops(block_b)
+    sch.reorder(j, i)
+    tvm.ir.assert_structural_equal(opaque_access_reorder, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=opaque_access)
+
+
+def test_reorder_with_predicate():
+    sch = tir.Schedule(elementwise_predicate, debug_mask="all")
+    block_b = sch.get_block("B")
+    i, j, k, l = sch.get_loops(block_b)
+    sch.reorder(l, i)
+    tvm.ir.assert_structural_equal(elementwise_reordered_with_predicate, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=elementwise_predicate)
+
+
+def test_reorder_fail_with_multi_appearance_loops():
+    sch = tir.Schedule(elementwise, debug_mask="all")
+    block_b = sch.get_block("B")
+    i, j, k, l = sch.get_loops(block_b)
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reorder(k, i, i)
+
+
+def test_reorder_fail_with_non_single_branch_loop():
+    sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all")
+    block_b = sch.get_block("B")
+    i, j, k = sch.get_loops(block_b)
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reorder(k, i)
+    sch = tir.Schedule(elementwise_non_single_branch, debug_mask="all")
+    block_b = sch.get_block("B")
+    block_c = sch.get_block("C")
+    i, j, k1 = sch.get_loops(block_b)
+    _, _, k2 = sch.get_loops(block_c)
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reorder(k1, i, k2)
+
+
+def test_reorder_fail_with_loops_not_under_same_scope():
+    sch = tir.Schedule(elementwise_with_loops_not_same_scope, debug_mask="all")
+    block_b = sch.get_block("B")
+    block_a = sch.get_block("A")
+    i, j = sch.get_loops(block_a)
+    k = sch.get_loops(block_b)[0]
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reorder(k, i)
+
+
+def test_reorder_fail_with_wrong_block_var_type():
+    sch = tir.Schedule(elementwise_with_wrong_block_var_type, debug_mask="all")
+    block_b = sch.get_block("B")
+    i, j, k = sch.get_loops(block_b)
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reorder(k, i)
+
+
+def test_reorder_fail_with_dependent_loops():
+    sch = tir.Schedule(elementwise_dependent_loop, debug_mask="all")
+    block_b = sch.get_block("B")
+    i, j, k, l = sch.get_loops(block_b)
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reorder(l, i)
+
+
+def test_reorder_fail_not_affine_bindings():
+    sch = tir.Schedule(elementwise_not_affine, debug_mask="all")
+    block_b = sch.get_block("B")
+    i, j, k, l = sch.get_loops(block_b)
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.reorder(l, i)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))