Skip to content

Commit

Permalink
[VISITOR] New ExprFunctor, StmtFunctor Interface. Modular analysis (#58)
Browse files Browse the repository at this point in the history
* [ARITH/VISITOR] Modular Analysis, ExprFunctor, StmtFunctor

* retrigger

* [IRFunctor] Migrated CodegenC

* [IRFUNCTOR] Migrate CodeGenLLVM

* [IRFunctor] Migrate canonical

* [IRFunctor] Migrate vectorize

* [IRFunctor] migrate CodeGenStackVM
  • Loading branch information
tqchen authored Mar 1, 2017
1 parent e438794 commit 7133448
Show file tree
Hide file tree
Showing 25 changed files with 2,028 additions and 1,471 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ after_failure:
- tests/travis/travis_after_failure.sh

notifications:
# Emails are sent to the committer's git-configured email address by default,
email:
on_success: change
on_failure: always
261 changes: 261 additions & 0 deletions include/tvm/ir_functor_ext.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
/*!
* Copyright (c) 2017 by Contributors
* \file ir_functor_ext.h
* \brief More powerful Visitor that allows define function signatures.
*/
#ifndef TVM_IR_FUNCTOR_EXT_H_
#define TVM_IR_FUNCTOR_EXT_H_

#include <tvm/ir_functor.h>
#include "./ir.h"

namespace tvm {
namespace ir {

/*!
* \brief A dynamical functor that dispatches on in the first Expr argument.
* You can use this as a more powerful Visitor, since it allows you to
* define function signatures of Visit Function.
*
* \code
* // A functor that set variable to b. and calculate results.
* class MyExprFunctor
* : public ir::ExprFunctor<int(const Expr&, int)> {
* public:
* int VisitExpr_(const Variable* op, int b) final {
* return b;
* }
* int VisitExpr_(const IntImm* op, int b) final {
* return op->value;
* }
* int VisitExpr_(const Add* op, int b) final {
* return Visit(op->a, b) + Visit(op->b, b);
* }
* };
* MyExprFunctor f;
* Var x("x");
* CHECK_EQ(f(x + 1, 2), 3);
* \endcode
*
* \note Why do we need this more powerful Functor:
*
* We often need to implement a transformer tasks.
* Say we want to take Expr and transform it to some analysis result,
* This easily be done incorrectly using plain Visitor. See IRVisitor's
* document for possible error cases.
*
* \tparam FType function signiture
* This type if only defined for FType with function signiture R(const Expr&, Args...)
*/
template<typename FType>
class ExprFunctor;
/*!
* \brief Same as ExprFunctor except it is applied on statements
* \tparam FType The function signature.
*/
template<typename FType>
class StmtFunctor;

// functions to be overriden.
#define EXPR_FUNCTOR_DEFAULT { \
return VisitExprDefault_(op, std::forward<Args>(args)...); \
}
#define STMT_FUNCTOR_DEFAULT { \
return VisitStmtDefault_(op, std::forward<Args>(args)...); \
}

#define IR_EXPR_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
return self->VisitExpr_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
}); \

#define IR_STMT_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
return self->VisitStmt_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
}); \

template<typename R, typename ...Args>
class ExprFunctor<R(const Expr& n, Args...)> {
private:
using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;

public:
/*! \brief the result type of this functor */
using result_type = R;
/*! \brief virtual destructor */
virtual ~ExprFunctor() {}
/*!
* \brief Same as call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
R operator()(const Expr& n, Args... args) {
return VisitExpr(n, std::forward<Args>(args)...);
}
/*!
* \brief The functor call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
virtual R VisitExpr(const Expr& n, Args... args) {
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
// Functions that can be overriden by subclass
virtual R VisitExpr_(const Variable* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Load* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Let* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Call* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Add* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Sub* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Mul* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Div* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Mod* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Min* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Max* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const EQ* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const NE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const GT* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const GE* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const And* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Or* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Reduce* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Cast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Not* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Select* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Ramp* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const Broadcast* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const IntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const UIntImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const FloatImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const StringImm* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args ...) {
LOG(FATAL) << "Do not have a default for " << op->type_key();
return R();
}

private:
// initialize the vtable.
static FType InitVTable() {
FType vtable;
// Set dispatch
IR_EXPR_FUNCTOR_DISPATCH(Variable);
IR_EXPR_FUNCTOR_DISPATCH(Load);
IR_EXPR_FUNCTOR_DISPATCH(Let);
IR_EXPR_FUNCTOR_DISPATCH(Call);
IR_EXPR_FUNCTOR_DISPATCH(Add);
IR_EXPR_FUNCTOR_DISPATCH(Sub);
IR_EXPR_FUNCTOR_DISPATCH(Mul);
IR_EXPR_FUNCTOR_DISPATCH(Div);
IR_EXPR_FUNCTOR_DISPATCH(Mod);
IR_EXPR_FUNCTOR_DISPATCH(Min);
IR_EXPR_FUNCTOR_DISPATCH(Max);
IR_EXPR_FUNCTOR_DISPATCH(EQ);
IR_EXPR_FUNCTOR_DISPATCH(NE);
IR_EXPR_FUNCTOR_DISPATCH(LT);
IR_EXPR_FUNCTOR_DISPATCH(LE);
IR_EXPR_FUNCTOR_DISPATCH(GT);
IR_EXPR_FUNCTOR_DISPATCH(GE);
IR_EXPR_FUNCTOR_DISPATCH(And);
IR_EXPR_FUNCTOR_DISPATCH(Or);
IR_EXPR_FUNCTOR_DISPATCH(Reduce);
IR_EXPR_FUNCTOR_DISPATCH(Cast);
IR_EXPR_FUNCTOR_DISPATCH(Not);
IR_EXPR_FUNCTOR_DISPATCH(Select);
IR_EXPR_FUNCTOR_DISPATCH(Ramp);
IR_EXPR_FUNCTOR_DISPATCH(Broadcast);
IR_EXPR_FUNCTOR_DISPATCH(IntImm);
IR_EXPR_FUNCTOR_DISPATCH(UIntImm);
IR_EXPR_FUNCTOR_DISPATCH(FloatImm);
IR_EXPR_FUNCTOR_DISPATCH(StringImm);
return vtable;
}
};

template<typename R, typename ...Args>
class StmtFunctor<R(const Stmt& n, Args... args)> {
private:
using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>;
using FType = IRFunctor<R(const NodeRef& n, TSelf* self, Args... args)>;

public:
/*! \brief the result type of this functor */
using result_type = R;
/*! \brief virtual destructor */
virtual ~StmtFunctor() {}
/*!
* \brief Same as call.
* \param n The stmt node.
* \param args Additional arguments.
* \return The result of the call
*/
R operator()(const Stmt& n, Args... args) {
return VisitStmt(n, std::forward<Args>(args)...);
}
/*!
* \brief The functor call.
* \param n The stmt node.
* \param args Additional arguments.
* \return The result of the call
*/
virtual R VisitStmt(const Stmt& n, Args... args) {
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
// Functions that can be overriden by subclass
virtual R VisitStmt_(const LetStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AttrStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const IfThenElse* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const For* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Allocate* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Store* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Free* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AssertStmt* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProducerConsumer* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Provide* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Realize* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Block* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const Evaluate* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const Node* op, Args ...) {
LOG(FATAL) << "Do not have a default for " << op->type_key();
return R();
}

private:
// initialize the vtable.
static FType InitVTable() {
FType vtable;
IR_STMT_FUNCTOR_DISPATCH(LetStmt);
IR_STMT_FUNCTOR_DISPATCH(AttrStmt);
IR_STMT_FUNCTOR_DISPATCH(IfThenElse);
IR_STMT_FUNCTOR_DISPATCH(For);
IR_STMT_FUNCTOR_DISPATCH(Allocate);
IR_STMT_FUNCTOR_DISPATCH(Store);
IR_STMT_FUNCTOR_DISPATCH(Free);
IR_STMT_FUNCTOR_DISPATCH(AssertStmt);
IR_STMT_FUNCTOR_DISPATCH(ProducerConsumer);
IR_STMT_FUNCTOR_DISPATCH(Provide);
IR_STMT_FUNCTOR_DISPATCH(Realize);
IR_STMT_FUNCTOR_DISPATCH(Block);
IR_STMT_FUNCTOR_DISPATCH(Evaluate);
return vtable;
}
};

#undef IR_STMT_FUNCTOR_DISPATCH
#undef IR_EXPR_FUNCTOR_DISPATCH
#undef EXPR_FUNCTOR_DEFAULT
#undef STMT_FUNCTOR_DEFAULT

} // namespace ir
} // namespace tvm
#endif // TVM_IR_FUNCTOR_EXT_H_
62 changes: 0 additions & 62 deletions include/tvm/ir_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,59 +55,23 @@ class IRMutator {
static FMutateStmt& vtable_stmt(); // NOLINT(*)
// Set of overloadable functions
// The underscore allows Mutate not to be shadowed by inheritance
virtual Stmt Mutate_(const Variable* op, const Stmt& s);
virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
virtual Stmt Mutate_(const For* op, const Stmt& s);
virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
virtual Stmt Mutate_(const Load* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Let* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const Call* op, const Stmt& s);
virtual Stmt Mutate_(const Add* op, const Stmt& e);
virtual Stmt Mutate_(const Sub* op, const Stmt& e);
virtual Stmt Mutate_(const Mul* op, const Stmt& e);
virtual Stmt Mutate_(const Div* op, const Stmt& e);
virtual Stmt Mutate_(const Mod* op, const Stmt& e);
virtual Stmt Mutate_(const Min* op, const Stmt& e);
virtual Stmt Mutate_(const Max* op, const Stmt& e);
virtual Stmt Mutate_(const EQ* op, const Stmt& e);
virtual Stmt Mutate_(const NE* op, const Stmt& e);
virtual Stmt Mutate_(const LT* op, const Stmt& e);
virtual Stmt Mutate_(const LE* op, const Stmt& e);
virtual Stmt Mutate_(const GT* op, const Stmt& e);
virtual Stmt Mutate_(const GE* op, const Stmt& e);
virtual Stmt Mutate_(const And* op, const Stmt& e);
virtual Stmt Mutate_(const Or* op, const Stmt& e);
virtual Stmt Mutate_(const Reduce* op, const Stmt& s);
virtual Stmt Mutate_(const Cast* op, const Stmt& s);
virtual Stmt Mutate_(const Not* op, const Stmt& s);
virtual Stmt Mutate_(const Select* op, const Stmt& s);
virtual Stmt Mutate_(const Ramp* op, const Stmt& s);
virtual Stmt Mutate_(const Broadcast* op, const Stmt& e);
virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e);
virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e);
virtual Stmt Mutate_(const Provide* op, const Stmt& e);
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Stmt Mutate_(const Evaluate* op, const Stmt& e);
virtual Stmt Mutate_(const IntImm* op, const Stmt& e);
virtual Stmt Mutate_(const UIntImm* op, const Stmt& e);
virtual Stmt Mutate_(const FloatImm* op, const Stmt& e);
virtual Stmt Mutate_(const StringImm* op, const Stmt& e);

virtual Expr Mutate_(const Variable* op, const Expr& e);
virtual Expr Mutate_(const LetStmt* op, const Expr& e);
virtual Expr Mutate_(const AttrStmt* op, const Expr& e);
virtual Expr Mutate_(const IfThenElse* op, const Expr& e);
virtual Expr Mutate_(const For* op, const Expr& e);
virtual Expr Mutate_(const Allocate* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& e);
virtual Expr Mutate_(const Store* op, const Expr& e);
virtual Expr Mutate_(const Let* op, const Expr& e);
virtual Expr Mutate_(const Free* op, const Expr& e);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Add* op, const Expr& e);
virtual Expr Mutate_(const Sub* op, const Expr& e);
Expand All @@ -130,38 +94,12 @@ class IRMutator {
virtual Expr Mutate_(const Select* op, const Expr& e);
virtual Expr Mutate_(const Ramp* op, const Expr& e);
virtual Expr Mutate_(const Broadcast* op, const Expr& e);
virtual Expr Mutate_(const AssertStmt* op, const Expr& e);
virtual Expr Mutate_(const ProducerConsumer* op, const Expr& e);
virtual Expr Mutate_(const Provide* op, const Expr& e);
virtual Expr Mutate_(const Realize* op, const Expr& e);
virtual Expr Mutate_(const Block* op, const Expr& e);
virtual Expr Mutate_(const Evaluate* op, const Expr& e);
virtual Expr Mutate_(const IntImm* op, const Expr& e);
virtual Expr Mutate_(const UIntImm* op, const Expr& e);
virtual Expr Mutate_(const FloatImm* op, const Expr& e);
virtual Expr Mutate_(const StringImm* op, const Expr& e);
};

/*!
* \brief Example on how to subclass and override behavior of IRMutator
*/
class IRMutatorExample : public IRMutator {
public:
Expr Mutate(Expr expr) final {
static const FMutateExpr& f = IRMutatorExample::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::Mutate(expr));
}
Stmt Mutate(Stmt stmt) final {
static const FMutateStmt& f = IRMutatorExample::vtable_stmt();
return (f.can_dispatch(stmt) ?
f(stmt, stmt, this) : IRMutator::Mutate(stmt));
}
// to be implemented by child class
static FMutateExpr& vtable_expr(); // NOLINT(*)
static FMutateStmt& vtable_stmt(); // NOLINT(*)
};

} // namespace ir
} // namespace tvm
#endif // TVM_IR_MUTATOR_H_
Loading

0 comments on commit 7133448

Please sign in to comment.