From 0d6675c02fb9503d8762a906c92bd7c1d838ceb1 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 6 Nov 2021 11:36:45 +0800 Subject: [PATCH] [SparseTIR] SparseBlock on C++/Python side (#11) * Fix a bug in the last commit * SparseBlock on C++ & Python side --- include/tvm/tir/stmt.h | 77 ++++++++++++++++++++++++++++------------ python/tvm/tir/sparse.py | 2 +- python/tvm/tir/stmt.py | 53 +++++++++++++++++++++++++++ src/tir/ir/stmt.cc | 65 ++++++++++++++++++++++++++++++--- 4 files changed, 170 insertions(+), 27 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index b00d84f01b09..80a382f1f9da 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -327,28 +327,6 @@ class BufferStore : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode); }; -/*! - * \brief Sparse Block node. - */ -class SparseBlockNode : public StmtNode { - public: - /*! \brief The sparse iteration variables of the block. */ - Array sp_iter_vars; - /*! \brief The sparse buffers defined in the block. */ - Array sp_buffers; - /*! \brief The body of the block */ - Stmt body; - - static constexpr const char* _type_key = "tir.SparseBlock"; - TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockNode, StmtNode); -}; - -class SparseBlock : public Stmt { - public: - TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode); -}; - - /*! * \brief Store value to the high dimension sparse buffer. * @@ -1300,6 +1278,61 @@ class BlockRealize : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode); }; +/*! + * \brief Sparse Block node. + */ +class SparseBlockNode : public StmtNode { + public: + /*! \brief The sparse iteration variables of the block. */ + Array sp_iter_vars; + /*! \brief The sparse buffers defined in the block. */ + Array sp_buffers; + /*! \brief The name of the block */ + String name; + /*! \brief The body of the block */ + Stmt body; + /*! \brief The init statement of the block */ + Optional init; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("sp_iter_vars", &sp_iter_vars); + v->Visit("sp_buffers", &sp_buffers); + v->Visit("name", &name); + v->Visit("body", &body); + v->Visit("init", &init); + } + + bool SEqualReduce(const SparseBlockNode* other, SEqualReducer equal) const { + return equal(sp_iter_vars, other->sp_iter_vars) && equal(sp_buffers, other->sp_buffers) && + equal(name, other->name) && equal(body, other->body) && equal(init, other->init); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(sp_iter_vars); + hash_reduce(sp_buffers); + hash_reduce(name); + hash_reduce(body); + hash_reduce(init); + } + + static constexpr const char* _type_key = "tir.SparseBlock"; + TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockNode, StmtNode); +}; + +/*! + * \brief Managed reference to SparseBufferNode + * \sa SparseBufferNode + */ +class SparseBlock : public Stmt { + public: + TVM_DLL explicit SparseBlock(Array sp_iter_vars, Array sp_buffers, + String name, Stmt body, Optional init = NullOpt, + Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBlockNode); +}; + /*! \brief namespace of possible attribute sin AttrStmt.attr_key */ namespace attr { // The above attr does not pass to ir stage. diff --git a/python/tvm/tir/sparse.py b/python/tvm/tir/sparse.py index 11302a14b1d8..4b0b857a8e6e 100644 --- a/python/tvm/tir/sparse.py +++ b/python/tvm/tir/sparse.py @@ -236,7 +236,7 @@ class SpIterVar(Object): SparseFixed = 2 SparseVariable = 3 - def __init__(self, var, max_extent, kind, axis=None): + def __init__(self, var, max_extent, kind, is_reduction, axis=None): self.__init_handle_by_constructor__( _ffi_api.SpIterVar, var, max_extent, kind, is_reduction, axis # type: ignore ) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index de200d5eabdd..84b91981ea89 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -36,6 +36,7 @@ from . import _ffi_api from .buffer import Buffer from .expr import IterVar +from .sparse import SpIterVar, SparseBuffer class Stmt(Object): @@ -614,6 +615,58 @@ def __init__( ) # type: ignore +@tvm._ffi.register_object("tir.SparseBlock") +class SparseBlock(Stmt): + """SparseBlock node. + + Parameters + ---------- + sp_iter_vars : List[SpIterVar] + The sparse iteration variables of the block. + + sp_buffers : List[SparseBuffer] + The sparse buffers defined in the block. + + name : str + The name of the block. + + body : Stmt + The body of the block. + + init : Optional[Stmt] + The init statement of the block. + + span : Optional[Span] + The location of this block in the source code. + """ + + sp_iter_vars: List[SpIterVar] + sp_buffers: List[SparseBuffer] + name: str + body: Stmt + init: Optional[Stmt] + span: Optional[Span] + + def __init__( + self, + sp_iter_vars: List[SpIterVar], + sp_buffers: List[SparseBuffer], + name: str, + body: Stmt, + init: Optional[Stmt] = None, + span: Optional[Span] = None, + ): + self.__init_handle_by_constructor__( + _ffi_api.SparseBlock, # type: ignore + sp_iter_vars, + sp_buffers, + name, + body, + init, + span, + ) # type: ignore + + @tvm._ffi.register_object("tir.BlockRealize") class BlockRealize(Stmt): """BlockRealize node. diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 0cf4dc18b060..1cc80dd4d73c 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -883,17 +883,21 @@ void PrintBlockSignature(const BlockNode* op, ReprPrinter* p) { } } -void PrintBlockBody(const BlockNode* op, ReprPrinter* p) { - // Print init - if (op->init.defined()) { +void PrintInitStmt(const Optional& init, ReprPrinter* p) { + if (init.defined()) { p->PrintIndent(); p->stream << "with init() {\n"; p->indent += 2; - p->Print(op->init.value()); + p->Print(init.value()); p->indent -= 2; p->PrintIndent(); p->stream << "}\n"; } +} + +void PrintBlockBody(const BlockNode* op, ReprPrinter* p) { + // Print init + PrintInitStmt(op->init, p); // Print body p->Print(op->body); } @@ -971,6 +975,59 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "}\n"; }); +SparseBlock::SparseBlock(Array sp_iter_vars, Array sp_buffers, String name, + Stmt body, Optional init, Span span) { + ObjectPtr node = make_object(); + node->sp_iter_vars = std::move(sp_iter_vars); + node->sp_buffers = std::move(sp_buffers); + node->name = std::move(name); + node->body = std::move(body); + node->init = std::move(init); + node->span = std::move(span); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.SparseBlock") + .set_body_typed([](Array sp_iter_vars, Array sp_buffers, String name, + Stmt body, Optional init, Span span) { + return SparseBlock(sp_iter_vars, sp_buffers, name, body, init, span); + }); + +TVM_REGISTER_NODE_TYPE(SparseBlockNode); + +void PrintSparseBlockTitle(const SparseBlockNode* op, ReprPrinter* p) { + p->stream << "sparse_block " << op->name << "("; + for (int i = 0; i < static_cast(op->sp_iter_vars.size()); ++i) { + p->Print(op->sp_iter_vars[i]); + if (i < static_cast(op->sp_iter_vars.size()) - 1) { + p->stream << ", "; + } + } + p->stream << ")"; +} + +void PrintSparseBlockBody(const SparseBlockNode* op, ReprPrinter* p) { + // Print init + PrintInitStmt(op->init, p); + // Print body + p->Print(op->body); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + PrintSparseBlockTitle(op, p); + p->stream << " {\n"; + p->indent += 2; + + PrintSparseBlockBody(op, p); + + p->indent -= 2; + p->PrintIndent(); + p->stream << "}\n"; + }); + PrimExpr TypeAnnotation(DataType dtype, Span span) { static auto op = Op::Get("tir.type_annotation"); return tir::Call(dtype, op, {}, span);