Skip to content

Commit

Permalink
[SparseTIR][Schedule] SparseBlockRV, GetSparseBlock, SparseReorder (#23)
Browse files Browse the repository at this point in the history
* Test initialization

* Fix a stupid bug of ReprPrinter

* Add SparseBlockRV

* Schedule: GetSparseBlock

* Schedule: Reorder
  • Loading branch information
MasterJH5574 authored and yzh119 committed Nov 22, 2021
1 parent c4e01bf commit 747d23c
Show file tree
Hide file tree
Showing 14 changed files with 1,082 additions and 174 deletions.
50 changes: 50 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <tvm/support/random_engine.h>
#include <tvm/tir/schedule/state.h>
#include <tvm/tir/schedule/trace.h>
#include <tvm/tir/sparse.h>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -85,6 +86,27 @@ using ExprRV = PrimExpr;

using ExprRVNode = PrimExprNode;

/**************** Random variable: SparseBlockRV ****************/

/*! \brief A random variable that evaluates to a TensorIR sparse block */
class SparseBlockRVNode : public runtime::Object {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "tir.SparseBlockRV";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockRVNode, runtime::Object);
};

/*!
* \brief Managed reference to SparseBlockRVNode
* \sa SparseBlockRVNode
*/
class SparseBlockRV : public runtime::ObjectRef {
public:
/*! \brief Constructor */
TVM_DLL SparseBlockRV();
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SparseBlockRV, runtime::ObjectRef, SparseBlockRVNode);
};

/**************** The Schedule class ****************/

class Schedule;
Expand Down Expand Up @@ -143,6 +165,12 @@ class ScheduleNode : public runtime::Object {
* \return The corresponding expr
*/
virtual PrimExpr Get(const ExprRV& expr_rv) const = 0;
/*!
* \brief Get the sparse block corresponding to the specific random variable
* \param sp_block_rv The random variable to be looked up
* \return SparseBlock The corresponding sparse block
*/
virtual SparseBlock Get(const SparseBlockRV& sp_block_rv) const = 0;
/*!
* \brief Get the block sref corresponding to the specific BlockRV
* \param block_rv The BlockRV to be looked up
Expand Down Expand Up @@ -182,6 +210,11 @@ class ScheduleNode : public runtime::Object {
* \param expr_rv The random variable to be removed
*/
virtual void RemoveRV(const ExprRV& expr_rv) = 0;
/*!
* \brief Remove an sparse block random variable from the symbol table
* \param sp_block_rv The random variable to be removed
*/
virtual void RemoveRV(const SparseBlockRV& sp_block_rv) = 0;

public:
/******** Schedule: Sampling ********/
Expand Down Expand Up @@ -453,6 +486,23 @@ class ScheduleNode : public runtime::Object {
/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
/******** Schedule: SparseTIR schedules ********/
/*!
* \brief Retrieve a sparse block in a specific function with its name
* \param name The name of the sparse block to be retrieved
* \param func_name The name of the function
* \return The sparse block retrieved
* \note Indexing error is raised if 0 or multiple blocks exist with the specific name
*/
virtual SparseBlockRV GetSparseBlock(const String& name, const String& func_name = "main") = 0;
/*!
* \brief Reorder a list of sparse iterators. It requires the new order to not break the iterator
* dependency.
* \param block The block to be transformed
* \param new_order The new order of the sparse iterators, whose length should equal to the number
* of the input block's sparse iterators
*/
virtual void SparseReorder(const SparseBlockRV& block_rv, const Array<SpIterVar>& new_order) = 0;
};

/*!
Expand Down
74 changes: 67 additions & 7 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.runtime import Object
from tvm.tir import Block, For, IntImm, PrimFunc
from tvm.tir import Block, For, IntImm, PrimFunc, SparseBlock
from tvm.tir.sparse import SpIterVar

from . import _ffi_api
from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
Expand Down Expand Up @@ -55,12 +56,23 @@ def __init__(self) -> None:
)


@_register_object("tir.SparseBlockRV")
class SparseBlockRV(Object):
"""A random variable that refers to a sparse block"""

def __init__(self) -> None:
"""Construct a new SparseBlockRV."""
self.__init_handle_by_constructor__(
_ffi_api.SparseBlockRV # type: ignore # pylint: disable=no-member
)


# It is a workaround for mypy: https://github.com/python/mypy/issues/7866#issuecomment-549454370
# This feature is not supported until python 3.10:
# https://docs.python.org/3.10/whatsnew/3.10.html#pep-613-typealias
ExprRV = Union[PrimExpr] # A random variable that evaluates to an integer

RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # pylint: disable=invalid-name
RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV, SparseBlockRV] # pylint: disable=invalid-name

# Update to `Literal["detail", "fast", "none"]` once upgraded to python3.8
_ERROR_RENDER_LEVEL: Dict[str, int] = {
Expand Down Expand Up @@ -223,7 +235,7 @@ def show(self, rand_var: RAND_VAR_TYPE) -> str:
Parameters
----------
rand_var : Union[ExprRV, BlockRV, LoopRV]
rand_var : Union[ExprRV, BlockRV, LoopRV, SparseBlockRV]
The random variable to be evaluated
Returns
Expand All @@ -238,22 +250,23 @@ def show(self, rand_var: RAND_VAR_TYPE) -> str:
def get(
self,
rand_var_or_sref: Union[RAND_VAR_TYPE, StmtSRef],
) -> Optional[Union[int, Block, For]]:
) -> Optional[Union[int, Block, For, SparseBlock]]:
"""Returns:
- the corresponding Block that a BlockRV evaluates to;
- the corresponding For that a LoopRV evaluates to;
- the corresponding integer that a ExprRV evaluates to;
- the corresponding SparseBlock that a SparseBlockRV evaluates to;
- the corresponding Block that a block sref points to;
- the corresponding For that a loop sref points to;
Parameters
----------
rand_var_or_sref : Union[ExprRV, BlockRV, LoopRV, StmtSRef]
rand_var_or_sref : Union[ExprRV, BlockRV, LoopRV, SparseBlockRV, StmtSRef]
The random variable / sref to be evaluated
Returns
-------
result : Optional[Union[int, Block, For]]
result : Optional[Union[int, Block, For, SparseBlock]]
The corresponding result
"""
if isinstance(rand_var_or_sref, StmtSRef):
Expand Down Expand Up @@ -289,7 +302,7 @@ def remove_rv(self, rand_var: RAND_VAR_TYPE) -> None:
Parameters
----------
rand_var : Union[BlockRV, LoopRV, ExprRV]
rand_var : Union[BlockRV, LoopRV, ExprRV, SparseBlockRV]
The random variable to be removed
"""
return _ffi_api.ScheduleRemoveRV(self, rand_var) # type: ignore # pylint: disable=no-member
Expand Down Expand Up @@ -1637,3 +1650,50 @@ def after_storage_align(a: T.handle, c: T.handle) -> None:
def enter_postproc(self) -> None:
"""A no-op that marks the start of postprocessing phase of scheduling"""
_ffi_api.ScheduleEnterPostproc(self) # type: ignore # pylint: disable=no-member

########## Schedule: SparseTIR schedules ##########

def get_sparse_block(
self,
name: str,
func_name: str = "main",
) -> SparseBlock:
"""Retrieve a sparse block in a specific function with its name
Parameters
----------
name : str
The name of the sparse block
func_name : str = "main"
The name of the function
Returns
-------
block : SparseBlockRV
The sparse block retrieved
IndexError is raised if 0 or multiple blocks exist with the specific name.
"""
return _ffi_api.ScheduleGetSparseBlock( # type: ignore # pylint: disable=no-member
self,
name,
func_name,
)

def sparse_reorder(self, block: SparseBlockRV, new_order: List[SpIterVar]) -> None:
"""Reorder a list of sparse iterators. It requires the new order to not break the iterator
dependency.
Parameters
----------
block : SparseBlockRV
The queried sparse block
new_order : List[SpIterVar]
The The new order of the sparse iterators, whose length should equal to the number
of the input block's sparse iterators
"""
return _ffi_api.ScheduleSparseReorder( # type: ignore # pylint: disable=no-member
self,
block,
new_order,
)
2 changes: 1 addition & 1 deletion src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ TVM_REGISTER_NODE_TYPE(SparseBufferStoreNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SparseBufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BufferStoreNode*>(node.get());
auto* op = static_cast<const SparseBufferStoreNode*>(node.get());
p->PrintIndent();
p->stream << op->buffer->name << "[";
for (size_t i = 0; i < op->indices.size(); ++i) {
Expand Down
Loading

0 comments on commit 747d23c

Please sign in to comment.