-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[VISITOR] New ExprFunctor, StmtFunctor Interface. Modular analysis (#58)
* [ARITH/VISITOR] Modular Analysis, ExprFunctor, StmtFunctor * retrigger * [IRFunctor] Migrated CodegenC * [IRFUNCTOR] Migrate CodeGenLLVM * [IRFunctor] Migrate canonical * [IRFunctor] Migrate vectorize * [IRFunctor] migrate CodeGenStackVM
- Loading branch information
Showing
25 changed files
with
2,028 additions
and
1,471 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,261 @@ | ||
/*! | ||
* Copyright (c) 2017 by Contributors | ||
* \file ir_functor_ext.h | ||
* \brief More powerful Visitor that allows define function signatures. | ||
*/ | ||
#ifndef TVM_IR_FUNCTOR_EXT_H_ | ||
#define TVM_IR_FUNCTOR_EXT_H_ | ||
|
||
#include <tvm/ir_functor.h> | ||
#include "./ir.h" | ||
|
||
namespace tvm { | ||
namespace ir { | ||
|
||
/*! | ||
* \brief A dynamical functor that dispatches on in the first Expr argument. | ||
* You can use this as a more powerful Visitor, since it allows you to | ||
* define function signatures of Visit Function. | ||
* | ||
* \code | ||
* // A functor that set variable to b. and calculate results. | ||
* class MyExprFunctor | ||
* : public ir::ExprFunctor<int(const Expr&, int)> { | ||
* public: | ||
* int VisitExpr_(const Variable* op, int b) final { | ||
* return b; | ||
* } | ||
* int VisitExpr_(const IntImm* op, int b) final { | ||
* return op->value; | ||
* } | ||
* int VisitExpr_(const Add* op, int b) final { | ||
* return Visit(op->a, b) + Visit(op->b, b); | ||
* } | ||
* }; | ||
* MyExprFunctor f; | ||
* Var x("x"); | ||
* CHECK_EQ(f(x + 1, 2), 3); | ||
* \endcode | ||
* | ||
* \note Why do we need this more powerful Functor: | ||
* | ||
* We often need to implement a transformer tasks. | ||
* Say we want to take Expr and transform it to some analysis result, | ||
* This easily be done incorrectly using plain Visitor. See IRVisitor's | ||
* document for possible error cases. | ||
* | ||
* \tparam FType function signiture | ||
* This type if only defined for FType with function signiture R(const Expr&, Args...) | ||
*/ | ||
template<typename FType> | ||
class ExprFunctor; | ||
/*! | ||
* \brief Same as ExprFunctor except it is applied on statements | ||
* \tparam FType The function signature. | ||
*/ | ||
template<typename FType> | ||
class StmtFunctor; | ||
|
||
// functions to be overriden. | ||
#define EXPR_FUNCTOR_DEFAULT { \ | ||
return VisitExprDefault_(op, std::forward<Args>(args)...); \ | ||
} | ||
#define STMT_FUNCTOR_DEFAULT { \ | ||
return VisitStmtDefault_(op, std::forward<Args>(args)...); \ | ||
} | ||
|
||
#define IR_EXPR_FUNCTOR_DISPATCH(OP) \ | ||
vtable.template set_dispatch<OP>( \ | ||
[](const NodeRef& n, TSelf* self, Args... args) { \ | ||
return self->VisitExpr_(static_cast<const OP*>(n.node_.get()), \ | ||
std::forward<Args>(args)...); \ | ||
}); \ | ||
|
||
#define IR_STMT_FUNCTOR_DISPATCH(OP) \ | ||
vtable.template set_dispatch<OP>( \ | ||
[](const NodeRef& n, TSelf* self, Args... args) { \ | ||
return self->VisitStmt_(static_cast<const OP*>(n.node_.get()), \ | ||
std::forward<Args>(args)...); \ | ||
}); \ | ||
|
||
template<typename R, typename ...Args> | ||
class ExprFunctor<R(const Expr& n, Args...)> { | ||
private: | ||
using TSelf = ExprFunctor<R(const Expr& n, Args...)>; | ||
using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>; | ||
|
||
public: | ||
/*! \brief the result type of this functor */ | ||
using result_type = R; | ||
/*! \brief virtual destructor */ | ||
virtual ~ExprFunctor() {} | ||
/*! | ||
* \brief Same as call. | ||
* \param n The expression node. | ||
* \param args Additional arguments. | ||
* \return The result of the call | ||
*/ | ||
R operator()(const Expr& n, Args... args) { | ||
return VisitExpr(n, std::forward<Args>(args)...); | ||
} | ||
/*! | ||
* \brief The functor call. | ||
* \param n The expression node. | ||
* \param args Additional arguments. | ||
* \return The result of the call | ||
*/ | ||
virtual R VisitExpr(const Expr& n, Args... args) { | ||
static FType vtable = InitVTable(); | ||
return vtable(n, this, std::forward<Args>(args)...); | ||
} | ||
// Functions that can be overriden by subclass | ||
virtual R VisitExpr_(const Variable* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const Load* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const Let* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const Call* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const Add* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const Sub* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const Mul* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const Div* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const Mod* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const Max* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const NE* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const LT* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const LE* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const GT* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const GE* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const And* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const Or* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const Reduce* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const Cast* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const Not* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const Broadcast* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const IntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT; | ||
virtual R VisitExprDefault_(const Node* op, Args ...) { | ||
LOG(FATAL) << "Do not have a default for " << op->type_key(); | ||
return R(); | ||
} | ||
|
||
private: | ||
// initialize the vtable. | ||
static FType InitVTable() { | ||
FType vtable; | ||
// Set dispatch | ||
IR_EXPR_FUNCTOR_DISPATCH(Variable); | ||
IR_EXPR_FUNCTOR_DISPATCH(Load); | ||
IR_EXPR_FUNCTOR_DISPATCH(Let); | ||
IR_EXPR_FUNCTOR_DISPATCH(Call); | ||
IR_EXPR_FUNCTOR_DISPATCH(Add); | ||
IR_EXPR_FUNCTOR_DISPATCH(Sub); | ||
IR_EXPR_FUNCTOR_DISPATCH(Mul); | ||
IR_EXPR_FUNCTOR_DISPATCH(Div); | ||
IR_EXPR_FUNCTOR_DISPATCH(Mod); | ||
IR_EXPR_FUNCTOR_DISPATCH(Min); | ||
IR_EXPR_FUNCTOR_DISPATCH(Max); | ||
IR_EXPR_FUNCTOR_DISPATCH(EQ); | ||
IR_EXPR_FUNCTOR_DISPATCH(NE); | ||
IR_EXPR_FUNCTOR_DISPATCH(LT); | ||
IR_EXPR_FUNCTOR_DISPATCH(LE); | ||
IR_EXPR_FUNCTOR_DISPATCH(GT); | ||
IR_EXPR_FUNCTOR_DISPATCH(GE); | ||
IR_EXPR_FUNCTOR_DISPATCH(And); | ||
IR_EXPR_FUNCTOR_DISPATCH(Or); | ||
IR_EXPR_FUNCTOR_DISPATCH(Reduce); | ||
IR_EXPR_FUNCTOR_DISPATCH(Cast); | ||
IR_EXPR_FUNCTOR_DISPATCH(Not); | ||
IR_EXPR_FUNCTOR_DISPATCH(Select); | ||
IR_EXPR_FUNCTOR_DISPATCH(Ramp); | ||
IR_EXPR_FUNCTOR_DISPATCH(Broadcast); | ||
IR_EXPR_FUNCTOR_DISPATCH(IntImm); | ||
IR_EXPR_FUNCTOR_DISPATCH(UIntImm); | ||
IR_EXPR_FUNCTOR_DISPATCH(FloatImm); | ||
IR_EXPR_FUNCTOR_DISPATCH(StringImm); | ||
return vtable; | ||
} | ||
}; | ||
|
||
template<typename R, typename ...Args> | ||
class StmtFunctor<R(const Stmt& n, Args... args)> { | ||
private: | ||
using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>; | ||
using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args... args)>; | ||
|
||
public: | ||
/*! \brief the result type of this functor */ | ||
using result_type = R; | ||
/*! \brief virtual destructor */ | ||
virtual ~StmtFunctor() {} | ||
/*! | ||
* \brief Same as call. | ||
* \param n The stmt node. | ||
* \param args Additional arguments. | ||
* \return The result of the call | ||
*/ | ||
R operator()(const Stmt& n, Args... args) { | ||
return VisitStmt(n, std::forward<Args>(args)...); | ||
} | ||
/*! | ||
* \brief The functor call. | ||
* \param n The stmt node. | ||
* \param args Additional arguments. | ||
* \return The result of the call | ||
*/ | ||
virtual R VisitStmt(const Stmt& n, Args... args) { | ||
static FType vtable = InitVTable(); | ||
return vtable(n, this, std::forward<Args>(args)...); | ||
} | ||
// Functions that can be overriden by subclass | ||
virtual R VisitStmt_(const LetStmt* op, Args... args) STMT_FUNCTOR_DEFAULT; | ||
virtual R VisitStmt_(const AttrStmt* op, Args... args) STMT_FUNCTOR_DEFAULT; | ||
virtual R VisitStmt_(const IfThenElse* op, Args... args) STMT_FUNCTOR_DEFAULT; | ||
virtual R VisitStmt_(const For* op, Args... args) STMT_FUNCTOR_DEFAULT; | ||
virtual R VisitStmt_(const Allocate* op, Args... args) STMT_FUNCTOR_DEFAULT; | ||
virtual R VisitStmt_(const Store* op, Args... args) STMT_FUNCTOR_DEFAULT; | ||
virtual R VisitStmt_(const Free* op, Args... args) STMT_FUNCTOR_DEFAULT; | ||
virtual R VisitStmt_(const AssertStmt* op, Args... args) STMT_FUNCTOR_DEFAULT; | ||
virtual R VisitStmt_(const ProducerConsumer* op, Args... args) STMT_FUNCTOR_DEFAULT; | ||
virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT; | ||
virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT; | ||
virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT; | ||
virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT; | ||
virtual R VisitStmtDefault_(const Node* op, Args ...) { | ||
LOG(FATAL) << "Do not have a default for " << op->type_key(); | ||
return R(); | ||
} | ||
|
||
private: | ||
// initialize the vtable. | ||
static FType InitVTable() { | ||
FType vtable; | ||
IR_STMT_FUNCTOR_DISPATCH(LetStmt); | ||
IR_STMT_FUNCTOR_DISPATCH(AttrStmt); | ||
IR_STMT_FUNCTOR_DISPATCH(IfThenElse); | ||
IR_STMT_FUNCTOR_DISPATCH(For); | ||
IR_STMT_FUNCTOR_DISPATCH(Allocate); | ||
IR_STMT_FUNCTOR_DISPATCH(Store); | ||
IR_STMT_FUNCTOR_DISPATCH(Free); | ||
IR_STMT_FUNCTOR_DISPATCH(AssertStmt); | ||
IR_STMT_FUNCTOR_DISPATCH(ProducerConsumer); | ||
IR_STMT_FUNCTOR_DISPATCH(Provide); | ||
IR_STMT_FUNCTOR_DISPATCH(Realize); | ||
IR_STMT_FUNCTOR_DISPATCH(Block); | ||
IR_STMT_FUNCTOR_DISPATCH(Evaluate); | ||
return vtable; | ||
} | ||
}; | ||
|
||
#undef IR_STMT_FUNCTOR_DISPATCH | ||
#undef IR_EXPR_FUNCTOR_DISPATCH | ||
#undef EXPR_FUNCTOR_DEFAULT | ||
#undef STMT_FUNCTOR_DEFAULT | ||
|
||
} // namespace ir | ||
} // namespace tvm | ||
#endif // TVM_IR_FUNCTOR_EXT_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.