Skip to content

Commit

Permalink
add memoized expr translator for use by backend codegen (#5325)
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi authored Apr 14, 2020
1 parent 0ab1803 commit 2c1ca60
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 115 deletions.
64 changes: 23 additions & 41 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,31 @@
* \file relay/backend/compile_engine.cc
* \brief Internal compialtion engine.
*/
#include "compile_engine.h"

#include <topi/tags.h>
#include <tvm/driver/driver_api.h>
#include <tvm/ir/type_functor.h>
#include <tvm/te/schedule.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/driver/driver_api.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>

#include <topi/tags.h>
#include <utility>
#include <functional>
#include <limits>
#include <mutex>
#include <functional>
#include <vector>
#include <unordered_map>
#include <utility>
#include <vector>

#include "compile_engine.h"
#include "utils.h"

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -111,8 +113,7 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {

// The getter to get schedule from compile engine.
// Get schedule from functor.
class ScheduleGetter :
public ExprFunctor<Array<te::Tensor>(const Expr&)> {
class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
public:
explicit ScheduleGetter(Target target)
: target_(target), device_copy_op_(Op::Get("device_copy")) {}
Expand Down Expand Up @@ -179,17 +180,6 @@ class ScheduleGetter :
return CachedFunc(cache_node);
}

Array<te::Tensor> VisitExpr(const Expr& expr) {
auto it = memo_.find(expr);
if (it != memo_.end()) {
return it->second;
} else {
Array<te::Tensor> res = ExprFunctor::VisitExpr(expr);
memo_[expr] = res;
return res;
}
}

Array<te::Tensor> VisitExpr_(const VarNode* op) final {
LOG(FATAL) << "Free variable " << op->name_hint();
return {};
Expand Down Expand Up @@ -327,15 +317,14 @@ class ScheduleGetter :
int master_op_pattern_{0};
OpImplementation master_implementation_;
std::ostringstream readable_name_stream_;
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> memo_;
Array<te::Operation> scalars_;
// Cache device copy op for equivalence checking to reduce registry lookup
// overhead for each invocation of call node when retrieving schedules.
const Op& device_copy_op_;
};

// Creates shape function from functor.
class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
public:
MakeShapeFunc() {}

Expand Down Expand Up @@ -422,19 +411,14 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
return std::make_pair(schedule, cfunc);
}

Array<te::Tensor> VisitExpr(const Expr& expr) {
auto it = memo_.find(expr);
if (it != memo_.end()) {
return it->second;
} else {
Array<te::Tensor> res = ExprFunctor::VisitExpr(expr);
if (expr.as<VarNode>() == nullptr) {
// Do not memoize vars because shape functions could use either the data
// or the shape of a var each time.
memo_[expr] = res;
}
return res;
Array<te::Tensor> VisitExpr(const Expr& expr) final {
if (expr.as<VarNode>()) {
// Do not memoize vars because shape functions could use either the data
// or the shape of a var each time.
return ExprFunctor::VisitExpr(expr);
}
// For other case, do memoized visit
return backend::MemoizedExprTranslator<Array<te::Tensor>>::VisitExpr(expr);
}

Array<te::Tensor> VisitExpr_(const VarNode* var_node) final {
Expand Down Expand Up @@ -577,8 +561,6 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> param_data_;
/*! \brief Map from parameter to list of shape placeholder */
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> param_shapes_;
/*! \brief Memoized visit result */
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> memo_;
/*! \brief Stack of data dependencies for shape function */
std::vector<bool> data_dependants_;
/*! \brief Scalars used in the shape function */
Expand Down
12 changes: 1 addition & 11 deletions src/relay/backend/contrib/codegen_c/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,10 @@ using namespace backend;
* purpose. Only several binary options are covered. Users
* may need to extend them to cover more operators.
*/
class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>,
public CodegenCBase {
class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
public:
explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; }

std::vector<Output> VisitExpr(const Expr& expr) final {
if (visited_.count(expr)) return visited_.at(expr);
std::vector<Output> output = ExprFunctor::VisitExpr(expr);
visited_[expr] = output;
return output;
}

std::vector<Output> VisitExprDefault_(const Object* op) final {
LOG(FATAL) << "C codegen doesn't support: " << op->GetTypeKey();
return {};
Expand Down Expand Up @@ -208,8 +200,6 @@ class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>,
std::vector<std::string> func_decl_;
/*! \brief The declaration statements of buffers. */
std::vector<std::string> buf_decl_;
/*! \brief The name and index pairs for output. */
std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
};

class CSourceCodegen : public CSourceModuleCodegenBase {
Expand Down
12 changes: 1 addition & 11 deletions src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,10 @@ std::vector<std::string> Add(const CallNode* call) {

// TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
// all utilities and make a base class for users to implement.
class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>,
public CodegenCBase {
class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
public:
explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; }

std::vector<Output> VisitExpr(const Expr& expr) final {
if (visited_.count(expr)) return visited_.at(expr);
std::vector<Output> output = ExprFunctor::VisitExpr(expr);
visited_[expr] = output;
return output;
}

std::vector<Output> VisitExprDefault_(const Object* op) final {
LOG(FATAL) << "DNNL codegen doesn't support: " << op->GetTypeKey();
return {};
Expand Down Expand Up @@ -343,8 +335,6 @@ class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>,
std::vector<std::string> ext_func_body;
/*! \brief The declaration of intermeidate buffers. */
std::vector<std::string> buf_decl_;
/*! \brief The cached expressions. */
std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
};

/*!
Expand Down
50 changes: 3 additions & 47 deletions src/relay/backend/graph_runtime_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,12 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>


#include <list>
#include <string>
#include <vector>

#include "utils.h"
#include "compile_engine.h"
#include "utils.h"

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -190,11 +189,9 @@ class GraphOpNode : public GraphNode {
};

/*! \brief Code generator for graph runtime */
class GraphRuntimeCodegen
: public ::tvm::relay::ExprFunctor<std::vector<GraphNodeRef>(const Expr&)> {
class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
public:
GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets)
: mod_(mod) {
GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) {
compile_engine_ = CompileEngine::Global();
targets_ = targets;
}
Expand Down Expand Up @@ -313,47 +310,6 @@ class GraphRuntimeCodegen
return {GraphNodeRef(node_id, 0)};
}

/*! \brief Visitors */
std::unordered_map<Expr, std::vector<GraphNodeRef>, ObjectHash, ObjectEqual> visitor_cache_;

std::vector<GraphNodeRef> VisitExpr(const Expr& expr) override {
if (visitor_cache_.count(expr)) return visitor_cache_.at(expr);
std::vector<GraphNodeRef> res;
if (expr.as<ConstantNode>()) {
res = VisitExpr_(expr.as<ConstantNode>());
} else if (expr.as<TupleNode>()) {
res = VisitExpr_(expr.as<TupleNode>());
} else if (expr.as<VarNode>()) {
res = VisitExpr_(expr.as<VarNode>());
} else if (expr.as<GlobalVarNode>()) {
res = VisitExpr_(expr.as<GlobalVarNode>());
} else if (expr.as<FunctionNode>()) {
res = VisitExpr_(expr.as<FunctionNode>());
} else if (expr.as<CallNode>()) {
res = VisitExpr_(expr.as<CallNode>());
} else if (expr.as<LetNode>()) {
res = VisitExpr_(expr.as<LetNode>());
} else if (expr.as<IfNode>()) {
res = VisitExpr_(expr.as<IfNode>());
} else if (expr.as<OpNode>()) {
res = VisitExpr_(expr.as<OpNode>());
} else if (expr.as<TupleGetItemNode>()) {
res = VisitExpr_(expr.as<TupleGetItemNode>());
} else if (expr.as<RefCreateNode>()) {
res = VisitExpr_(expr.as<RefCreateNode>());
} else if (expr.as<RefReadNode>()) {
res = VisitExpr_(expr.as<RefReadNode>());
} else if (expr.as<RefWriteNode>()) {
res = VisitExpr_(expr.as<RefWriteNode>());
} else if (expr.as<ConstructorNode>()) {
res = VisitExpr_(expr.as<ConstructorNode>());
} else if (expr.as<MatchNode>()) {
res = VisitExpr_(expr.as<MatchNode>());
}
visitor_cache_[expr] = res;
return res;
}

std::vector<GraphNodeRef> VisitExpr_(const VarNode* op) override {
Expr expr = GetRef<Expr>(op);
return var_map_[expr.get()];
Expand Down
5 changes: 0 additions & 5 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,6 @@ class Interpreter :
return VisitExpr(expr);
}

ObjectRef VisitExpr(const Expr& expr) final {
auto ret = ExprFunctor<ObjectRef(const Expr& n)>::VisitExpr(expr);
return ret;
}

ObjectRef VisitExpr_(const VarNode* var_node) final {
return Lookup(GetRef<Var>(var_node));
}
Expand Down
35 changes: 35 additions & 0 deletions src/relay/backend/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <dmlc/json.h>
#include <tvm/driver/driver_api.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/target/codegen.h>
Expand All @@ -42,6 +43,40 @@
namespace tvm {
namespace relay {
namespace backend {

/*!
* \brief A simple wrapper around ExprFunctor for a single argument case.
* The result of visit is memoized.
*/
template <typename OutputType>
class MemoizedExprTranslator : public ::tvm::relay::ExprFunctor<OutputType(const Expr&)> {
using BaseFunctor = ::tvm::relay::ExprFunctor<OutputType(const Expr&)>;

public:
/*! \brief virtual destructor */
virtual ~MemoizedExprTranslator() {}

/*!
* \brief The memoized call.
* \param n The expression node.
* \return The result of the call
*/
virtual OutputType VisitExpr(const Expr& n) {
CHECK(n.defined());
auto it = memo_.find(n);
if (it != memo_.end()) {
return it->second;
}
auto res = BaseFunctor::VisitExpr(n);
memo_[n] = res;
return res;
}

protected:
/*! \brief Internal map used for memoization. */
std::unordered_map<Expr, OutputType, ObjectHash, ObjectEqual> memo_;
};

/*!
* \brief Get the Packed Func
*
Expand Down

0 comments on commit 2c1ca60

Please sign in to comment.