diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 4ed8fbc15abd..ce0a314f265b 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -21,29 +21,31 @@ * \file relay/backend/compile_engine.cc * \brief Internal compialtion engine. */ +#include "compile_engine.h" + +#include +#include #include -#include -#include -#include -#include -#include -#include #include +#include #include #include #include #include -#include +#include +#include +#include +#include +#include -#include -#include +#include #include #include -#include -#include #include +#include +#include -#include "compile_engine.h" +#include "utils.h" namespace tvm { namespace relay { @@ -111,8 +113,7 @@ Array GetShape(const Array& shape) { // The getter to get schedule from compile engine. // Get schedule from functor. -class ScheduleGetter : - public ExprFunctor(const Expr&)> { +class ScheduleGetter : public backend::MemoizedExprTranslator> { public: explicit ScheduleGetter(Target target) : target_(target), device_copy_op_(Op::Get("device_copy")) {} @@ -179,17 +180,6 @@ class ScheduleGetter : return CachedFunc(cache_node); } - Array VisitExpr(const Expr& expr) { - auto it = memo_.find(expr); - if (it != memo_.end()) { - return it->second; - } else { - Array res = ExprFunctor::VisitExpr(expr); - memo_[expr] = res; - return res; - } - } - Array VisitExpr_(const VarNode* op) final { LOG(FATAL) << "Free variable " << op->name_hint(); return {}; @@ -327,7 +317,6 @@ class ScheduleGetter : int master_op_pattern_{0}; OpImplementation master_implementation_; std::ostringstream readable_name_stream_; - std::unordered_map, ObjectHash, ObjectEqual> memo_; Array scalars_; // Cache device copy op for equivalence checking to reduce registry lookup // overhead for each invocation of call node when retrieving schedules. @@ -335,7 +324,7 @@ class ScheduleGetter : }; // Creates shape function from functor. -class MakeShapeFunc : public ExprFunctor(const Expr&)> { +class MakeShapeFunc : public backend::MemoizedExprTranslator> { public: MakeShapeFunc() {} @@ -422,19 +411,14 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { return std::make_pair(schedule, cfunc); } - Array VisitExpr(const Expr& expr) { - auto it = memo_.find(expr); - if (it != memo_.end()) { - return it->second; - } else { - Array res = ExprFunctor::VisitExpr(expr); - if (expr.as() == nullptr) { - // Do not memoize vars because shape functions could use either the data - // or the shape of a var each time. - memo_[expr] = res; - } - return res; + Array VisitExpr(const Expr& expr) final { + if (expr.as()) { + // Do not memoize vars because shape functions could use either the data + // or the shape of a var each time. + return ExprFunctor::VisitExpr(expr); } + // For other case, do memoized visit + return backend::MemoizedExprTranslator>::VisitExpr(expr); } Array VisitExpr_(const VarNode* var_node) final { @@ -577,8 +561,6 @@ class MakeShapeFunc : public ExprFunctor(const Expr&)> { std::unordered_map, ObjectHash, ObjectEqual> param_data_; /*! \brief Map from parameter to list of shape placeholder */ std::unordered_map, ObjectHash, ObjectEqual> param_shapes_; - /*! \brief Memoized visit result */ - std::unordered_map, ObjectHash, ObjectEqual> memo_; /*! \brief Stack of data dependencies for shape function */ std::vector data_dependants_; /*! \brief Scalars used in the shape function */ diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index fc93b73e5ca9..0b3510c85779 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -40,18 +40,10 @@ using namespace backend; * purpose. Only several binary options are covered. Users * may need to extend them to cover more operators. */ -class CodegenC : public ExprFunctor(const Expr&)>, - public CodegenCBase { +class CodegenC : public MemoizedExprTranslator>, public CodegenCBase { public: explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; } - std::vector VisitExpr(const Expr& expr) final { - if (visited_.count(expr)) return visited_.at(expr); - std::vector output = ExprFunctor::VisitExpr(expr); - visited_[expr] = output; - return output; - } - std::vector VisitExprDefault_(const Object* op) final { LOG(FATAL) << "C codegen doesn't support: " << op->GetTypeKey(); return {}; @@ -208,8 +200,6 @@ class CodegenC : public ExprFunctor(const Expr&)>, std::vector func_decl_; /*! \brief The declaration statements of buffers. */ std::vector buf_decl_; - /*! \brief The name and index pairs for output. */ - std::unordered_map, ObjectHash, ObjectEqual> visited_; }; class CSourceCodegen : public CSourceModuleCodegenBase { diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 48652fc19f37..26bc8786902c 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -128,18 +128,10 @@ std::vector Add(const CallNode* call) { // TODO(@zhiics, @comaniac): This is a basic implementation. We should implement // all utilities and make a base class for users to implement. -class CodegenDNNL : public ExprFunctor(const Expr&)>, - public CodegenCBase { +class CodegenDNNL : public MemoizedExprTranslator>, public CodegenCBase { public: explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; } - std::vector VisitExpr(const Expr& expr) final { - if (visited_.count(expr)) return visited_.at(expr); - std::vector output = ExprFunctor::VisitExpr(expr); - visited_[expr] = output; - return output; - } - std::vector VisitExprDefault_(const Object* op) final { LOG(FATAL) << "DNNL codegen doesn't support: " << op->GetTypeKey(); return {}; @@ -343,8 +335,6 @@ class CodegenDNNL : public ExprFunctor(const Expr&)>, std::vector ext_func_body; /*! \brief The declaration of intermeidate buffers. */ std::vector buf_decl_; - /*! \brief The cached expressions. */ - std::unordered_map, ObjectHash, ObjectEqual> visited_; }; /*! diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 4279db0110d8..7b686c76e3e7 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -28,13 +28,12 @@ #include #include - #include #include #include -#include "utils.h" #include "compile_engine.h" +#include "utils.h" namespace tvm { namespace relay { @@ -190,11 +189,9 @@ class GraphOpNode : public GraphNode { }; /*! \brief Code generator for graph runtime */ -class GraphRuntimeCodegen - : public ::tvm::relay::ExprFunctor(const Expr&)> { +class GraphRuntimeCodegen : public backend::MemoizedExprTranslator> { public: - GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) - : mod_(mod) { + GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) { compile_engine_ = CompileEngine::Global(); targets_ = targets; } @@ -313,47 +310,6 @@ class GraphRuntimeCodegen return {GraphNodeRef(node_id, 0)}; } - /*! \brief Visitors */ - std::unordered_map, ObjectHash, ObjectEqual> visitor_cache_; - - std::vector VisitExpr(const Expr& expr) override { - if (visitor_cache_.count(expr)) return visitor_cache_.at(expr); - std::vector res; - if (expr.as()) { - res = VisitExpr_(expr.as()); - } else if (expr.as()) { - res = VisitExpr_(expr.as()); - } else if (expr.as()) { - res = VisitExpr_(expr.as()); - } else if (expr.as()) { - res = VisitExpr_(expr.as()); - } else if (expr.as()) { - res = VisitExpr_(expr.as()); - } else if (expr.as()) { - res = VisitExpr_(expr.as()); - } else if (expr.as()) { - res = VisitExpr_(expr.as()); - } else if (expr.as()) { - res = VisitExpr_(expr.as()); - } else if (expr.as()) { - res = VisitExpr_(expr.as()); - } else if (expr.as()) { - res = VisitExpr_(expr.as()); - } else if (expr.as()) { - res = VisitExpr_(expr.as()); - } else if (expr.as()) { - res = VisitExpr_(expr.as()); - } else if (expr.as()) { - res = VisitExpr_(expr.as()); - } else if (expr.as()) { - res = VisitExpr_(expr.as()); - } else if (expr.as()) { - res = VisitExpr_(expr.as()); - } - visitor_cache_[expr] = res; - return res; - } - std::vector VisitExpr_(const VarNode* op) override { Expr expr = GetRef(op); return var_map_[expr.get()]; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 631f2d433be5..465f788449e2 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -244,11 +244,6 @@ class Interpreter : return VisitExpr(expr); } - ObjectRef VisitExpr(const Expr& expr) final { - auto ret = ExprFunctor::VisitExpr(expr); - return ret; - } - ObjectRef VisitExpr_(const VarNode* var_node) final { return Lookup(GetRef(var_node)); } diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index a96ffe4720fc..65e6ae9e79c6 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -42,6 +43,40 @@ namespace tvm { namespace relay { namespace backend { + +/*! + * \brief A simple wrapper around ExprFunctor for a single argument case. + * The result of visit is memoized. + */ +template +class MemoizedExprTranslator : public ::tvm::relay::ExprFunctor { + using BaseFunctor = ::tvm::relay::ExprFunctor; + + public: + /*! \brief virtual destructor */ + virtual ~MemoizedExprTranslator() {} + + /*! + * \brief The memoized call. + * \param n The expression node. + * \return The result of the call + */ + virtual OutputType VisitExpr(const Expr& n) { + CHECK(n.defined()); + auto it = memo_.find(n); + if (it != memo_.end()) { + return it->second; + } + auto res = BaseFunctor::VisitExpr(n); + memo_[n] = res; + return res; + } + + protected: + /*! \brief Internal map used for memoization. */ + std::unordered_map memo_; +}; + /*! * \brief Get the Packed Func *