Skip to content

Commit

Permalink
[TensorIR] introduce Block and BlockRealize (#312)
Browse files Browse the repository at this point in the history
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Tianqi Chen <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
  • Loading branch information
7 people committed Mar 1, 2021
1 parent 2673309 commit a61828c
Show file tree
Hide file tree
Showing 8 changed files with 869 additions and 1 deletion.
248 changes: 247 additions & 1 deletion include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ class For : public Stmt {
};

/*!
* \brief A prefetch hint for abuffer
* \brief A prefetch hint for a buffer
*/
class PrefetchNode : public StmtNode {
public:
Expand Down Expand Up @@ -905,6 +905,252 @@ class Prefetch : public Stmt {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
};

/*!
* \brief Representing the region of multi-dimensional buffer access.
*/
class BufferRegionNode : public Object {
public:
/*! \brief The buffer of the buffer region. */
Buffer buffer;
/*! \brief The region array of the buffer region. */
Array<Range> region;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer", &buffer);
v->Visit("region", &region);
}

bool SEqualReduce(const BufferRegionNode* other, SEqualReducer equal) const {
return equal(buffer, other->buffer) && equal(region, other->region);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer);
hash_reduce(region);
}

static constexpr const char* _type_key = "tir.BufferRegion";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRegionNode, Object);
};

/*!
* \brief Managed reference to BufferRegionNode.
* \sa BufferRegionNode
*/
class BufferRegion : public ObjectRef {
public:
TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region);

/*!
* \brief Create a BufferRegion which is full region of the given buffer..
* \param buffer The buffer to generate full BufferRegion.
* \return The BufferRegion which covers all region of the given buffer
*/
TVM_DLL static BufferRegion FullRegion(Buffer buffer);

TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, ObjectRef, BufferRegionNode);
};

/*!
* \brief Match introduces a constraint that the source buffer region can be remapped to the data
* layout specified by the buffer field. The constraint can be checked in later part of lowering (or
* optionally during runtime).
*
* MatchBufferRegion provides a mechanism to represent data layout and compactness constraints in
* low-level hardware primitives in the IR and defer the check after the sequence of
* transformations.
*/
class MatchBufferRegionNode : public Object {
public:
/*! \brief The target buffer. */
Buffer buffer;
/*! \brief The source buffer region. */
BufferRegion source;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer", &buffer);
v->Visit("source", &source);
}

bool SEqualReduce(const MatchBufferRegionNode* other, SEqualReducer equal) const {
return equal(buffer, other->buffer) && equal(source, other->source);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer);
hash_reduce(source);
}

static constexpr const char* _type_key = "tir.MatchBufferRegion";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(MatchBufferRegionNode, Object);
};

/*!
* \brief Managed reference to MatchBufferRegionNode.
* \sa MatchBufferRegionNode
*/
class MatchBufferRegion : public ObjectRef {
public:
TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);

TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode);
};

/*!
* \brief A block is a basic schedule unit in TIR.
* \note Block's body is parameterized by iter vars.
* \code
*
* with tir.block([extent0, extent1, ...], name) as [v0, v1, ...]:
* tir.bind(v0, value0)
* tir.bind(v1, value1)
* ...
* tir.reads([buffer0[start:end, ...], ...])
* tir.writes([buffer1[start:end, ...], ...])
* tir.where(predicate)
* buffer2 = tir.alloc_buffer(shape, dtype)
* buffer3 = tir.match_buffer(source_buffer[start:end, ...])
* tir.attr({attr_key: attr_value, ...})
* with tir.init():
* // init body
* // body
*
* \endcode
*/
class BlockNode : public StmtNode {
public:
/*! \brief The variables of the block. */
Array<IterVar> iter_vars;
/*! \brief The read buffer regions of the block. */
Array<BufferRegion> reads;
/*! \brief The write buffer regions of the block. */
Array<BufferRegion> writes;
/*! \brief The name_hint of the block. */
String name_hint;
/*! \brief The body of the block. */
Stmt body;
/*!
* \brief The init statement is executed during the first iteration of reduction loops in a
* reduction block. The optional init field allows us to represent initialization and
* reduction update in a single block and transform them collectively.
* We also provide primitives to decompose the init into a separate block during scheduling.
* Init field is `NullOpt` if there is no reduction iter_vars
*/
Optional<Stmt> init;
/*! \brief The buffer allocated in the block. */
Array<Buffer> alloc_buffers;
/*! \brief The match buffer regions. */
Array<MatchBufferRegion> match_buffers;
/*! \brief The annotation of the block. */
Map<String, ObjectRef> annotations;

void VisitAttrs(AttrVisitor* v) {
v->Visit("iter_vars", &iter_vars);
v->Visit("reads", &reads);
v->Visit("writes", &writes);
v->Visit("name_hint", &name_hint);
v->Visit("body", &body);
v->Visit("init", &init);
v->Visit("alloc_buffers", &alloc_buffers);
v->Visit("match_buffers", &match_buffers);
v->Visit("annotations", &annotations);
}

bool SEqualReduce(const BlockNode* other, SEqualReducer equal) const {
// Need first reduce iter_vars, alloc_buffers and match_buffers to define new vars
return equal.DefEqual(iter_vars, other->iter_vars) &&
equal(alloc_buffers, other->alloc_buffers) &&
equal(match_buffers, other->match_buffers) && equal(reads, other->reads) &&
equal(writes, other->writes) && equal(body, other->body) && equal(init, other->init) &&
equal(annotations, other->annotations);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(iter_vars);
hash_reduce(alloc_buffers);
hash_reduce(match_buffers);
hash_reduce(reads);
hash_reduce(writes);
hash_reduce(body);
hash_reduce(init);
hash_reduce(annotations);
}

static constexpr const char* _type_key = "tir.Block";
TVM_DECLARE_FINAL_OBJECT_INFO(BlockNode, StmtNode);
};

/*!
* \brief Managed reference to BlockNode.
* \sa BlockNode
*/
class Block : public Stmt {
public:
TVM_DLL explicit Block(Array<IterVar> iter_vars, Array<BufferRegion> reads,
Array<BufferRegion> writes, String name_hint, Stmt body,
Optional<Stmt> init = NullOpt,
Array<Buffer> alloc_buffers = Array<Buffer>(),
Array<MatchBufferRegion> match_buffers = Array<MatchBufferRegion>(),
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Block, Stmt, BlockNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockNode);
};

/*!
* \brief A block realization node represents execution of the block at the binding values.
*/
class BlockRealizeNode : public StmtNode {
public:
/*! \brief The corresponding values of the iter vars. */
Array<PrimExpr> iter_values;
/*!
* \brief The predicate of the block realization, the block will only be executed when the
* predicate is true.
*/
PrimExpr predicate;
/*! \brief The block to be realized. */
Block block;

void VisitAttrs(AttrVisitor* v) {
v->Visit("iter_values", &iter_values);
v->Visit("predicate", &predicate);
v->Visit("block", &block);
}

bool SEqualReduce(const BlockRealizeNode* other, SEqualReducer equal) const {
return equal(iter_values, other->iter_values) && equal(predicate, other->predicate) &&
equal(block, other->block);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(iter_values);
hash_reduce(predicate);
hash_reduce(block);
}

static constexpr const char* _type_key = "tir.BlockRealize";
TVM_DECLARE_FINAL_OBJECT_INFO(BlockRealizeNode, StmtNode);
};

/*!
* \brief Managed reference to BlockRealizeNode
* \sa BlockRealizeNode
*/
class BlockRealize : public Stmt {
public:
TVM_DLL explicit BlockRealize(Array<PrimExpr> iter_values, PrimExpr predicate, Block block,
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(BlockRealize, Stmt, BlockRealizeNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode);
};

/*! \brief namespace of possible attribute sin AttrStmt.attr_key */
namespace attr {
// The above attr does not pass to ir stage.
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
return R();
Expand All @@ -119,6 +121,8 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode);
IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
IR_STMT_FUNCTOR_DISPATCH(BlockNode);
IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode);
return vtable;
}
};
Expand Down Expand Up @@ -158,6 +162,8 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const PrefetchNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const BlockNode* op) override;
void VisitStmt_(const BlockRealizeNode* op) override;
};

/*!
Expand Down Expand Up @@ -249,6 +255,8 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const PrefetchNode* op) override;
Stmt VisitStmt_(const SeqStmtNode* op) override;
Stmt VisitStmt_(const EvaluateNode* op) override;
Stmt VisitStmt_(const BlockNode* op) override;
Stmt VisitStmt_(const BlockRealizeNode* op) override;
/*!
* \brief Alternative advance method for SeqStmtNode.
*
Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
from .stmt import ProducerRealize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize

from .function import PrimFunc

Expand Down
Loading

0 comments on commit a61828c

Please sign in to comment.