From 6b3ec497ff840384316152d7d1c4abdda5fc50d6 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 16 Dec 2021 01:46:00 +0800 Subject: [PATCH] [TensorIR] Primitive "SetScope" (#9738) * Main code * Reorder steps * Unittests * Docstring * Check the input storage scope * Docstring for `CheckStorageScope` * Import header --- include/tvm/tir/schedule/schedule.h | 8 + python/tvm/tir/schedule/schedule.py | 71 +++++++ src/tir/ir/functor_common.h | 2 + src/tir/schedule/analysis.h | 8 + src/tir/schedule/analysis/analysis.cc | 31 +++ src/tir/schedule/concrete_schedule.cc | 8 + src/tir/schedule/concrete_schedule.h | 1 + src/tir/schedule/primitive.h | 11 + src/tir/schedule/primitive/block_annotate.cc | 192 +++++++++++++++++- .../schedule/primitive/cache_read_write.cc | 7 + src/tir/schedule/schedule.cc | 2 + src/tir/schedule/traced_schedule.cc | 11 + src/tir/schedule/traced_schedule.h | 1 + .../test_tir_schedule_cache_read_write.py | 14 ++ .../unittest/test_tir_schedule_set_scope.py | 134 ++++++++++++ 15 files changed, 497 insertions(+), 4 deletions(-) create mode 100644 tests/python/unittest/test_tir_schedule_set_scope.py diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 44592b0832d0..57e6fb961a9b 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -448,6 +448,14 @@ class ScheduleNode : public runtime::Object { */ virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) = 0; + /*! + * \brief Set the storage scope of a buffer, where the buffer is specified by the a block and a + * write-index + * \param block_rv The producer block of the buffer + * \param buffer_index The index of the buffer in block's write region + * \param storage_scope The storage scope to be set + */ + virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0; /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ /******** Schedule: Misc ********/ diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 3a1d71495672..2e34b33ef1c3 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -1660,6 +1660,77 @@ def after_storage_align(a: T.handle, c: T.handle) -> None: self, block, buffer_index, axis, factor, offset ) + @type_checked + def set_scope(self, block: BlockRV, buffer_index: int, storage_scope: str) -> None: + """Set the storage scope of a buffer, where the buffer is + specified by the a block and a write-index + + Parameters + ---------- + block : BlockRV + The producer block of the buffer + buffer_index : int + The index of the buffer in block's write region + storage_scope : str + The storage scope to be set + + Examples + -------- + + Before set_scope, in TensorIR, the IR is: + + .. code-block:: python + + @T.prim_func + def before_set_scope( + A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"] + ) -> None: + B = T.alloc_buffer((128, 128), dtype="float32") + + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + Create the schedule and do set_scope: + + .. 
code-block:: python + + sch = tir.Schedule(before_set_scope) + sch.set_scope(sch.get_block("B"), buffer_index=0, storage_scope="shared") + print(sch.mod["main"].script()) + + After applying set_scope, the IR becomes: + + .. code-block:: python + + @T.prim_func + def after_set_scope( + A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"] + ) -> None: + B_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") + + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B_shared[vi, vj] = A[vi, vj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B_shared[vi, vj] + T.float32(1) + + Note + ---- + Set_scope requires the buffer to be an intermediate buffer defined via `alloc_buffer`. + """ + _ffi_api.ScheduleSetScope( # type: ignore # pylint: disable=no-member + self, block, buffer_index, storage_scope + ) + ########## Schedule: Blockize & Tensorize ########## ########## Schedule: Annotation ########## diff --git a/src/tir/ir/functor_common.h b/src/tir/ir/functor_common.h index 9ed911f6b782..8b5a361a37c6 100644 --- a/src/tir/ir/functor_common.h +++ b/src/tir/ir/functor_common.h @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ +#include + /*! * \file tir/ir/functor_common.h * \brief Common utils for implementing functors diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 82f4afa7a24c..ae72d592339f 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -395,6 +395,14 @@ bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, /******** Misc ********/ +/*! + * \brief Check whether the input storage scope string is valid. Throw an error if not. + * \param self The schedule state + * \param storage_scope The storage scope string to be checked + * \throw ScheduleError If the input storage scope is not valid + */ +void CheckStorageScope(const ScheduleState& self, String storage_scope); + /*! * \brief Checks if a block could be successfully computed inline into its consumer * \param self The schedule state diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 417c6331f496..6d744a66b498 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1343,5 +1343,36 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref) { return GetRef(p); } +/******** Storage Scope ********/ + +void CheckStorageScope(const ScheduleState& self, String storage_scope) { + class InvalidStorageScopeError : public ScheduleError { + public: + explicit InvalidStorageScopeError(IRModule mod, String storage_scope) + : mod_(std::move(mod)), storage_scope_(std::move(storage_scope)) {} + + String FastErrorString() const final { + return "ScheduleError: The input storage scope is invalid"; + } + + String DetailRenderTemplate() const final { + return "The input storage scope \"" + storage_scope_ + "\" is invalid."; + } + + Array LocationsOfInterest() const final { return {}; } + IRModule mod() const final { return mod_; } + + private: + IRModule mod_; + String storage_scope_; + }; + + try { + runtime::StorageScope::Create(std::string(storage_scope)); + } catch (...) 
{ + throw InvalidStorageScopeError(self->mod, std::move(storage_scope)); + } +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 4db4cd4ba1c8..65bd50788685 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -541,6 +541,14 @@ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_inde this->state_->DebugVerify(); } +void ConcreteScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, + const String& storage_scope) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::SetScope(state_, this->GetSRef(block_rv), buffer_index, storage_scope); + TVM_TIR_SCHEDULE_END("set-scope", this->error_render_level_); + this->state_->DebugVerify(); +} + /******** Schedule: Reduction ********/ BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index e4df5f893ae9..f9404ae7dc8c 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -118,6 +118,7 @@ class ConcreteScheduleNode : public ScheduleNode { /******** Schedule: Block annotation ********/ void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) override; + void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override; /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ /******** Schedule: Misc ********/ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index cc7e44d4df9e..520e70bf2475 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -329,6 +329,7 @@ using StorageAlignAnnotation = Array; * more friendly memory access pattern. For example, we can set alignment to be factor=2, * offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared * memory. + * \param self The state of the schedule * \param block_sref The producer block of the buffer * \param buffer_index The index of the buffer in block's write region * \param axis The dimension to be specified for alignment @@ -337,6 +338,16 @@ using StorageAlignAnnotation = Array; */ TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis, int factor, int offset); +/*! + * \brief Set the storage scope of a buffer, where the buffer is specified by the a block and a + * write-index + * \param self The state of the schedule + * \param block_sref The sref of the producer block of the buffer + * \param buffer_index The index of the buffer in block's write region + * \param storage_scope The storage scope to be set + */ +TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + const String& storage_scope); /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 894f7a9d027e..181e5a6cfa69 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include "../../ir/functor_common.h" #include "../utils.h" namespace tvm { @@ -118,14 +119,16 @@ class NonAllocatedBufferError : public ScheduleError { return os.str(); } - static void CheckBufferAllocated(const IRModule& mod, const StmtSRef& block_sref, - const Buffer& buffer) { + static StmtSRef CheckAndGetBufferAllocationSite(const IRModule& mod, const StmtSRef& block_sref, + const Buffer& buffer) { Optional defining_site_sref; bool is_alloc; std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, buffer); - if (!defining_site_sref || !is_alloc) { + if (!defining_site_sref.defined() || !is_alloc) { throw NonAllocatedBufferError(mod, buffer); } + + return defining_site_sref.value(); } Array LocationsOfInterest() const final { return {}; } @@ -233,6 +236,133 @@ class StorageAlignInvalidAnnotationError : public ScheduleError { Block block_; }; +/*! + * \brief A helper mutator which recursively mutates the old buffer's storage scope and collects + * the block sref reuse information for the following replacement. + */ +class StorageScopeMutator : StmtExprMutator { + public: + /*! + * \param allocate_site The block where `old_buffer` was allocated. + * \param old_buffer The old buffer + * \param storage_scope The storage scope to be set + * \param block_sref_reuse The block sref reuse map to be updated + * \return The new block after the mutation + */ + static Block Mutate(const Block& allocate_site, const Buffer& old_buffer, + const String& storage_scope, Map* block_sref_reuse) { + Buffer new_buffer = WithScope(old_buffer, storage_scope); + StorageScopeMutator mutator(old_buffer, new_buffer, storage_scope, block_sref_reuse); + Stmt new_block = mutator.VisitStmt(allocate_site); + return Downcast(new_block); + } + + private: + StorageScopeMutator(const Buffer& old_buffer, Buffer new_buffer, String storage_scope, + Map* block_sref_reuse) + : storage_scope_(std::move(storage_scope)), block_sref_reuse_(block_sref_reuse) { + buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer); + } + + PrimExpr VisitExpr_(const VarNode* var) final { + auto it = buffer_var_map_.find(var); + return it != buffer_var_map_.end() ? it->second->data : GetRef(var); + } + + PrimExpr VisitExpr_(const BufferLoadNode* load) final { + BufferLoad res = Downcast(ExprMutator::VisitExpr_(load)); + + auto it = buffer_var_map_.find(res->buffer->data.get()); + if (it != buffer_var_map_.end()) { + ObjectPtr ptr = make_object(*res.get()); + ptr->buffer = it->second; + return PrimExpr(ptr); + } else { + return res; + } + } + + Stmt VisitStmt_(const BufferStoreNode* store) final { + BufferStore res = Downcast(StmtMutator::VisitStmt_(store)); + + auto it = buffer_var_map_.find(res->buffer->data.get()); + if (it != buffer_var_map_.end()) { + ObjectPtr ptr = make_object(*res.get()); + ptr->buffer = it->second; + return Stmt(ptr); + } else { + return res; + } + } + + Stmt VisitStmt_(const BlockNode* block) final { + // To reduce the number of blocks in block sref reuse map, we check whether the block is really + // mutated (i.e., the old buffer appears in the block). If so, we return the block after + // mutation. Otherwise we just return the original block. + + // Define the mutation functions. 
+ auto f_mutate_match_buffers = [this](const MatchBufferRegion& match_buffer) { + auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get()); + if (it != buffer_var_map_.end()) { + Buffer new_target_buffer = WithScope(match_buffer->buffer, storage_scope_); + buffer_var_map_[match_buffer->buffer->data.get()] = new_target_buffer; + return MatchBufferRegion(new_target_buffer, + BufferRegion(it->second, match_buffer->source->region)); + } else { + return match_buffer; + } + }; + auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) { + auto it = buffer_var_map_.find(buffer_region->buffer->data.get()); + return it == buffer_var_map_.end() ? buffer_region + : BufferRegion(it->second, buffer_region->region); + }; + auto f_mutate_alloc_buffers = [this](const Buffer& buffer) { + auto it = buffer_var_map_.find(buffer->data.get()); + return it == buffer_var_map_.end() ? buffer : it->second; + }; + + // Step 1. Mutate `match_buffers`. If an old buffer appears as a source of MatchBufferRegion, + // the storage scope of the target buffer also needs to be set. + Array match_buffers = + MutateArray(block->match_buffers, f_mutate_match_buffers); + // Step 2. Mutate the read/write region. + Array reads = MutateArray(block->reads, f_mutate_read_write_region); + Array writes = MutateArray(block->writes, f_mutate_read_write_region); + // Step 3. Mutate `alloc_buffers` for the old buffer allocated in this block. + Array alloc_buffers = MutateArray(block->alloc_buffers, f_mutate_alloc_buffers); + // Step 4. Recursively mutate the block. + Block mutated_block = Downcast(StmtMutator::VisitStmt_(block)); + + if (mutated_block.get() == block && reads.same_as(mutated_block->reads) && + writes.same_as(mutated_block->writes) && + alloc_buffers.same_as(mutated_block->alloc_buffers) && + match_buffers.same_as(mutated_block->match_buffers)) { + return GetRef(block); + } else { + ObjectPtr n = CopyOnWrite(mutated_block.get()); + n->reads = std::move(reads); + n->writes = std::move(writes); + n->alloc_buffers = std::move(alloc_buffers); + n->match_buffers = std::move(match_buffers); + + Block new_block(n); + block_sref_reuse_->Set(GetRef(block), new_block); + return new_block; + } + } + + /*! \brief The storage scope to be set. */ + String storage_scope_; + /*! + * \brief A mapping which maps old buffer vars to new buffers, including the buffers defined in + * MatchBufferRegion. + */ + std::unordered_map buffer_var_map_; + /*! \brief The block sref reuse map for the following replacement */ + Map* block_sref_reuse_; +}; + void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis, int factor, int offset) { const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); @@ -240,7 +370,7 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, /*is_write=*/true); StorageAlignInvalidFactorError::Check(self->mod, factor); axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis); - NonAllocatedBufferError::CheckBufferAllocated(self->mod, block_sref, buffer); + NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer); // Step 1: Get existing or create new annotation value. 
StorageAlignAnnotation storage_align_annotation = @@ -269,6 +399,32 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind self->Replace(block_sref, new_block, {{GetRef(block_ptr), new_block}}); } +void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + const String& storage_scope) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Buffer buffer = GetNthAccessBuffer(self, GetRef(block), buffer_index, true); + + // Step 1. If `storage_scope` equals the original storage scope of the buffer, just return. + if (buffer.scope() == storage_scope) { + return; + } + + // Step 2. Throw an error if the input storage scope is invalid. + CheckStorageScope(self, storage_scope); + + // Step 3. Get the allocation site of the target buffer. + StmtSRef alloc_site_sref = + NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer); + const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site, alloc_site_sref); + + // Step 4. Recursively replace the old buffer to a new buffer, where the new buffer has the given + // storage scope. In the meanwhile, collect the block sref reuse information. + Map block_reuse_map; + Block new_block = StorageScopeMutator::Mutate(GetRef(alloc_site), buffer, storage_scope, + &block_reuse_map); + self->Replace(alloc_site_sref, new_block, block_reuse_map); +} + /******** InstructionKind Registration ********/ struct StorageAlignTraits : public UnpackedInstTraits { @@ -301,7 +457,35 @@ struct StorageAlignTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; +struct SetScopeTraits : public UnpackedInstTraits { + static constexpr const char* kName = "SetScope"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, + String storage_scope) { + return sch->SetScope(block_rv, buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, + String storage_scope) { + PythonAPICall py("set_scope"); + py.Input("block", block_rv); + py.Input("buffer_index", buffer_index); + py.Input("storage_scope", storage_scope); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits); +TVM_REGISTER_INST_KIND_TRAITS(SetScopeTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 8628cc3c0791..159171ecae31 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -624,6 +624,9 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff * - Copy the buffer with the consumed region. */ + // Step 0. Check the input storage scope. + CheckStorageScope(self, storage_scope); + // Step 1. Check index, getting the target buffer and the parent scope const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); Buffer read_buffer = @@ -692,6 +695,10 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu * - Find the lowest ancestor of the block and ANY ONE of the producer blocks. * - Copy the buffer with the consumed region. */ + + // Step 0. Check the input storage scope. 
+ CheckStorageScope(self, storage_scope); + // Step 1. Checking index, getting the target buffer and the parent scope const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); Buffer write_buffer = diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index a411e40b13b6..c375184830a6 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -180,6 +180,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor") /******** (FFI) Block annotation ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") .set_body_method(&ScheduleNode::StorageAlign); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope") + .set_body_method(&ScheduleNode::SetScope); /******** (FFI) Blockize & Tensorize ********/ /******** (FFI) Annotation ********/ /******** (FFI) Misc ********/ diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 4a028d1dad5c..f18f9ade436c 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -329,6 +329,17 @@ void TracedScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, /*outputs=*/{})); } +void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index, + const String& storage_scope) { + ConcreteScheduleNode::SetScope(block_rv, buffer_index, storage_scope); + static const InstructionKind& kind = InstructionKind::Get("SetScope"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(buffer_index), storage_scope}, + /*outputs=*/{})); +} + /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index ac36b9ca06a9..aa4bbb2e0099 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -84,6 +84,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { /******** Schedule: Block annotation ********/ void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, int offset) final; + void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final; /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ /******** Schedule: Misc ********/ diff --git a/tests/python/unittest/test_tir_schedule_cache_read_write.py b/tests/python/unittest/test_tir_schedule_cache_read_write.py index 853f44affe5d..22f26ce0318a 100644 --- a/tests/python/unittest/test_tir_schedule_cache_read_write.py +++ b/tests/python/unittest/test_tir_schedule_cache_read_write.py @@ -684,6 +684,13 @@ def test_cache_read_fail_index_out_of_bound(): sch.cache_read(block_b, 1, "global") +def test_cache_read_fail_invalid_storage_scope(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + with pytest.raises(tvm.tir.ScheduleError): + sch.cache_read(block_b, 0, "test_scope") + + ########## Testcases for cache_write ########## @@ -759,5 +766,12 @@ def test_cache_write_fail_index_out_of_bound(): sch.cache_write(block_b, 1, "global") +def test_cache_write_fail_invalid_storage_scope(): + sch = tir.Schedule(elementwise, debug_mask="all") + block_b = sch.get_block("B") + with pytest.raises(tvm.tir.ScheduleError): + sch.cache_write(block_b, 0, "test_scope") + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_set_scope.py b/tests/python/unittest/test_tir_schedule_set_scope.py new file mode 100644 
index 000000000000..29c4880f7762 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_set_scope.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg + +@T.prim_func +def element_wise(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None: + B = T.alloc_buffer((128, 128), dtype="float32") + + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 + + +@T.prim_func +def element_wise_set_scope(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None: + B_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") + + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B_shared[vi, vj] = A[vi, vj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B_shared[vi, vj] + T.float32(1) + + +@T.prim_func +def element_wise_subregion_match(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None: + B = T.alloc_buffer((128, 128), dtype="float32") + + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B_subregion0 = T.match_buffer(B[i, j], [], offset_factor=1) + B_subregion0[()] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + B_subregion1 = T.match_buffer(B[i, j], [], offset_factor=1) + C[vi, vj] = B_subregion1[()] + 1.0 + + +@T.prim_func +def element_wise_subregion_match_set_scope(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None: + B_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") + + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B_subregion0_shared = T.match_buffer(B_shared[i, j], [], dtype="float32", scope="shared", offset_factor=1) + B_subregion0_shared[()] = A[vi, vj] * T.float32(2) + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + B_subregion1_shared = T.match_buffer(B_shared[i, j], [], dtype="float32", scope="shared", offset_factor=1) + C[vi, vj] = B_subregion1_shared[()] + T.float32(1) + + +# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg + + +def test_set_scope(): + func = element_wise + s = 
tir.Schedule(func, debug_mask='all') + s.set_scope(s.get_block("B"), 0, "shared") + tvm.ir.assert_structural_equal(element_wise_set_scope, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_set_scope_fail_on_output_buffer(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + with pytest.raises(tvm.tir.ScheduleError): + s.set_scope(s.get_block("C"), 0, "shared") + + +def test_set_scope_fail_on_index_out_of_bound(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + with pytest.raises(tvm.tir.ScheduleError): + s.set_scope(s.get_block("B"), 1, "shared") + with pytest.raises(tvm.tir.ScheduleError): + s.set_scope(s.get_block("B"), -1, "shared") + + +def test_set_scope_fail_on_invalid_scope(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + with pytest.raises(tvm.tir.ScheduleError): + s.set_scope(s.get_block("B"), 0, "test_scope") + + +def test_set_scope_subregion(): + func = element_wise_subregion_match + s = tir.Schedule(func, debug_mask='all') + s.set_scope(s.get_block("B"), 0, "shared") + tvm.ir.assert_structural_equal(element_wise_subregion_match_set_scope, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +if __name__ == "__main__": + test_set_scope() + test_set_scope_fail_on_output_buffer() + test_set_scope_fail_on_index_out_of_bound() + test_set_scope_fail_on_invalid_scope() + test_set_scope_subregion()
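
For reference, below is a minimal usage sketch of the new primitive (not part of the patch itself). It assumes the patch is applied, and uses a workload equivalent to the `element_wise` function defined in the new test file; the names `workload` and `sch` are illustrative only.

    # Usage sketch (hypothetical, not from the diff above): exercising the new
    # set_scope primitive on a workload mirroring `element_wise` from the tests.
    import tvm
    from tvm import tir
    from tvm.script import tir as T


    @T.prim_func
    def workload(a: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (128, 128), "float32")
        C = T.match_buffer(c, (128, 128), "float32")
        B = T.alloc_buffer((128, 128), dtype="float32")
        for i, j in T.grid(128, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0
        for i, j in T.grid(128, 128):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = B[vi, vj] + 1.0


    sch = tir.Schedule(workload, debug_mask="all")
    # Move the intermediate buffer written by block "B" into shared memory.
    sch.set_scope(sch.get_block("B"), buffer_index=0, storage_scope="shared")
    print(sch.mod["main"].script())

    # An unrecognized scope string is rejected, mirroring
    # test_set_scope_fail_on_invalid_scope above:
    #   sch.set_scope(sch.get_block("B"), buffer_index=0, storage_scope="test_scope")
    #   -> raises tvm.tir.ScheduleError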