Skip to content

Commit

Permalink
[TensorIR] Primitive "SetScope" (apache#9738)
Browse files Browse the repository at this point in the history
* Main code

* Reorder steps

* Unittests

* Docstring

* Check the input storage scope

* Docstring for `CheckStorageScope`

* Import header
  • Loading branch information
MasterJH5574 authored and ylc committed Jan 7, 2022
1 parent a833255 commit 76cbc6f
Show file tree
Hide file tree
Showing 15 changed files with 497 additions and 4 deletions.
8 changes: 8 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,14 @@ class ScheduleNode : public runtime::Object {
*/
virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) = 0;
/*!
* \brief Set the storage scope of a buffer, where the buffer is specified by the a block and a
* write-index
* \param block_rv The producer block of the buffer
* \param buffer_index The index of the buffer in block's write region
* \param storage_scope The storage scope to be set
*/
virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0;
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/******** Schedule: Misc ********/
Expand Down
71 changes: 71 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1660,6 +1660,77 @@ def after_storage_align(a: T.handle, c: T.handle) -> None:
self, block, buffer_index, axis, factor, offset
)

@type_checked
def set_scope(self, block: BlockRV, buffer_index: int, storage_scope: str) -> None:
"""Set the storage scope of a buffer, where the buffer is
specified by the a block and a write-index
Parameters
----------
block : BlockRV
The producer block of the buffer
buffer_index : int
The index of the buffer in block's write region
storage_scope : str
The storage scope to be set
Examples
--------
Before set_scope, in TensorIR, the IR is:
.. code-block:: python
@T.prim_func
def before_set_scope(
A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]
) -> None:
B = T.alloc_buffer((128, 128), dtype="float32")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0
Create the schedule and do set_scope:
.. code-block:: python
sch = tir.Schedule(before_set_scope)
sch.set_scope(sch.get_block("B"), buffer_index=0, storage_scope="shared")
print(sch.mod["main"].script())
After applying set_scope, the IR becomes:
.. code-block:: python
@T.prim_func
def after_set_scope(
A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]
) -> None:
B_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B_shared[vi, vj] = A[vi, vj] * T.float32(2)
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B_shared[vi, vj] + T.float32(1)
Note
----
Set_scope requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
"""
_ffi_api.ScheduleSetScope( # type: ignore # pylint: disable=no-member
self, block, buffer_index, storage_scope
)

########## Schedule: Blockize & Tensorize ##########

########## Schedule: Annotation ##########
Expand Down
2 changes: 2 additions & 0 deletions src/tir/ir/functor_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/container/array.h>

/*!
* \file tir/ir/functor_common.h
* \brief Common utils for implementing functors
Expand Down
8 changes: 8 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,14 @@ bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner,

/******** Misc ********/

/*!
* \brief Check whether the input storage scope string is valid. Throw an error if not.
* \param self The schedule state
* \param storage_scope The storage scope string to be checked
* \throw ScheduleError If the input storage scope is not valid
*/
void CheckStorageScope(const ScheduleState& self, String storage_scope);

/*!
* \brief Checks if a block could be successfully computed inline into its consumer
* \param self The schedule state
Expand Down
31 changes: 31 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1343,5 +1343,36 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref) {
return GetRef<StmtSRef>(p);
}

/******** Storage Scope ********/

void CheckStorageScope(const ScheduleState& self, String storage_scope) {
class InvalidStorageScopeError : public ScheduleError {
public:
explicit InvalidStorageScopeError(IRModule mod, String storage_scope)
: mod_(std::move(mod)), storage_scope_(std::move(storage_scope)) {}

String FastErrorString() const final {
return "ScheduleError: The input storage scope is invalid";
}

String DetailRenderTemplate() const final {
return "The input storage scope \"" + storage_scope_ + "\" is invalid.";
}

Array<ObjectRef> LocationsOfInterest() const final { return {}; }
IRModule mod() const final { return mod_; }

private:
IRModule mod_;
String storage_scope_;
};

try {
runtime::StorageScope::Create(std::string(storage_scope));
} catch (...) {
throw InvalidStorageScopeError(self->mod, std::move(storage_scope));
}
}

} // namespace tir
} // namespace tvm
8 changes: 8 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,14 @@ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_inde
this->state_->DebugVerify();
}

void ConcreteScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index,
const String& storage_scope) {
TVM_TIR_SCHEDULE_BEGIN();
tir::SetScope(state_, this->GetSRef(block_rv), buffer_index, storage_scope);
TVM_TIR_SCHEDULE_END("set-scope", this->error_render_level_);
this->state_->DebugVerify();
}

/******** Schedule: Reduction ********/

BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) {
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class ConcreteScheduleNode : public ScheduleNode {
/******** Schedule: Block annotation ********/
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) override;
void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override;
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/******** Schedule: Misc ********/
Expand Down
11 changes: 11 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ using StorageAlignAnnotation = Array<StorageAlignTuple>;
* more friendly memory access pattern. For example, we can set alignment to be factor=2,
* offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared
* memory.
* \param self The state of the schedule
* \param block_sref The producer block of the buffer
* \param buffer_index The index of the buffer in block's write region
* \param axis The dimension to be specified for alignment
Expand All @@ -337,6 +338,16 @@ using StorageAlignAnnotation = Array<StorageAlignTuple>;
*/
TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
int axis, int factor, int offset);
/*!
* \brief Set the storage scope of a buffer, where the buffer is specified by the a block and a
* write-index
* \param self The state of the schedule
* \param block_sref The sref of the producer block of the buffer
* \param buffer_index The index of the buffer in block's write region
* \param storage_scope The storage scope to be set
*/
TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
const String& storage_scope);

/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
Expand Down
Loading

0 comments on commit 76cbc6f

Please sign in to comment.