Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Add memoized expr translator for use by backend codegen #5325

Merged
merged 1 commit into from
Apr 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 23 additions & 41 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,31 @@
* \file relay/backend/compile_engine.cc
* \brief Internal compialtion engine.
*/
#include "compile_engine.h"

#include <topi/tags.h>
#include <tvm/driver/driver_api.h>
#include <tvm/ir/type_functor.h>
#include <tvm/te/schedule.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/driver/driver_api.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>

#include <topi/tags.h>
#include <utility>
#include <functional>
#include <limits>
#include <mutex>
#include <functional>
#include <vector>
#include <unordered_map>
#include <utility>
#include <vector>

#include "compile_engine.h"
#include "utils.h"

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -111,8 +113,7 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {

// The getter to get schedule from compile engine.
// Get schedule from functor.
class ScheduleGetter :
public ExprFunctor<Array<te::Tensor>(const Expr&)> {
class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
public:
explicit ScheduleGetter(Target target)
: target_(target), device_copy_op_(Op::Get("device_copy")) {}
Expand Down Expand Up @@ -179,17 +180,6 @@ class ScheduleGetter :
return CachedFunc(cache_node);
}

Array<te::Tensor> VisitExpr(const Expr& expr) {
auto it = memo_.find(expr);
if (it != memo_.end()) {
return it->second;
} else {
Array<te::Tensor> res = ExprFunctor::VisitExpr(expr);
memo_[expr] = res;
return res;
}
}

Array<te::Tensor> VisitExpr_(const VarNode* op) final {
LOG(FATAL) << "Free variable " << op->name_hint();
return {};
Expand Down Expand Up @@ -327,15 +317,14 @@ class ScheduleGetter :
int master_op_pattern_{0};
OpImplementation master_implementation_;
std::ostringstream readable_name_stream_;
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> memo_;
Array<te::Operation> scalars_;
// Cache device copy op for equivalence checking to reduce registry lookup
// overhead for each invocation of call node when retrieving schedules.
const Op& device_copy_op_;
};

// Creates shape function from functor.
class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
public:
MakeShapeFunc() {}

Expand Down Expand Up @@ -422,19 +411,14 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
return std::make_pair(schedule, cfunc);
}

Array<te::Tensor> VisitExpr(const Expr& expr) {
auto it = memo_.find(expr);
if (it != memo_.end()) {
return it->second;
} else {
Array<te::Tensor> res = ExprFunctor::VisitExpr(expr);
if (expr.as<VarNode>() == nullptr) {
// Do not memoize vars because shape functions could use either the data
// or the shape of a var each time.
memo_[expr] = res;
}
return res;
Array<te::Tensor> VisitExpr(const Expr& expr) final {
if (expr.as<VarNode>()) {
// Do not memoize vars because shape functions could use either the data
// or the shape of a var each time.
return ExprFunctor::VisitExpr(expr);
}
// For other case, do memoized visit
return backend::MemoizedExprTranslator<Array<te::Tensor>>::VisitExpr(expr);
}

Array<te::Tensor> VisitExpr_(const VarNode* var_node) final {
Expand Down Expand Up @@ -577,8 +561,6 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> param_data_;
/*! \brief Map from parameter to list of shape placeholder */
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> param_shapes_;
/*! \brief Memoized visit result */
std::unordered_map<Expr, Array<te::Tensor>, ObjectHash, ObjectEqual> memo_;
/*! \brief Stack of data dependencies for shape function */
std::vector<bool> data_dependants_;
/*! \brief Scalars used in the shape function */
Expand Down
12 changes: 1 addition & 11 deletions src/relay/backend/contrib/codegen_c/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,10 @@ using namespace backend;
* purpose. Only several binary options are covered. Users
* may need to extend them to cover more operators.
*/
class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>,
public CodegenCBase {
class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
public:
explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; }

std::vector<Output> VisitExpr(const Expr& expr) final {
if (visited_.count(expr)) return visited_.at(expr);
std::vector<Output> output = ExprFunctor::VisitExpr(expr);
visited_[expr] = output;
return output;
}

std::vector<Output> VisitExprDefault_(const Object* op) final {
LOG(FATAL) << "C codegen doesn't support: " << op->GetTypeKey();
return {};
Expand Down Expand Up @@ -208,8 +200,6 @@ class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>,
std::vector<std::string> func_decl_;
/*! \brief The declaration statements of buffers. */
std::vector<std::string> buf_decl_;
/*! \brief The name and index pairs for output. */
std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
};

class CSourceCodegen : public CSourceModuleCodegenBase {
Expand Down
12 changes: 1 addition & 11 deletions src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,10 @@ std::vector<std::string> Add(const CallNode* call) {

// TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
// all utilities and make a base class for users to implement.
class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>,
public CodegenCBase {
class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
public:
explicit CodegenDNNL(const std::string& id) { this->ext_func_id_ = id; }

std::vector<Output> VisitExpr(const Expr& expr) final {
if (visited_.count(expr)) return visited_.at(expr);
std::vector<Output> output = ExprFunctor::VisitExpr(expr);
visited_[expr] = output;
return output;
}

std::vector<Output> VisitExprDefault_(const Object* op) final {
LOG(FATAL) << "DNNL codegen doesn't support: " << op->GetTypeKey();
return {};
Expand Down Expand Up @@ -343,8 +335,6 @@ class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>,
std::vector<std::string> ext_func_body;
/*! \brief The declaration of intermeidate buffers. */
std::vector<std::string> buf_decl_;
/*! \brief The cached expressions. */
std::unordered_map<Expr, std::vector<Output>, ObjectHash, ObjectEqual> visited_;
};

/*!
Expand Down
50 changes: 3 additions & 47 deletions src/relay/backend/graph_runtime_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,12 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>


#include <list>
#include <string>
#include <vector>

#include "utils.h"
#include "compile_engine.h"
#include "utils.h"

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -190,11 +189,9 @@ class GraphOpNode : public GraphNode {
};

/*! \brief Code generator for graph runtime */
class GraphRuntimeCodegen
: public ::tvm::relay::ExprFunctor<std::vector<GraphNodeRef>(const Expr&)> {
class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
public:
GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets)
: mod_(mod) {
GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) : mod_(mod) {
compile_engine_ = CompileEngine::Global();
targets_ = targets;
}
Expand Down Expand Up @@ -313,47 +310,6 @@ class GraphRuntimeCodegen
return {GraphNodeRef(node_id, 0)};
}

/*! \brief Visitors */
std::unordered_map<Expr, std::vector<GraphNodeRef>, ObjectHash, ObjectEqual> visitor_cache_;

std::vector<GraphNodeRef> VisitExpr(const Expr& expr) override {
if (visitor_cache_.count(expr)) return visitor_cache_.at(expr);
std::vector<GraphNodeRef> res;
if (expr.as<ConstantNode>()) {
res = VisitExpr_(expr.as<ConstantNode>());
} else if (expr.as<TupleNode>()) {
res = VisitExpr_(expr.as<TupleNode>());
} else if (expr.as<VarNode>()) {
res = VisitExpr_(expr.as<VarNode>());
} else if (expr.as<GlobalVarNode>()) {
res = VisitExpr_(expr.as<GlobalVarNode>());
} else if (expr.as<FunctionNode>()) {
res = VisitExpr_(expr.as<FunctionNode>());
} else if (expr.as<CallNode>()) {
res = VisitExpr_(expr.as<CallNode>());
} else if (expr.as<LetNode>()) {
res = VisitExpr_(expr.as<LetNode>());
} else if (expr.as<IfNode>()) {
res = VisitExpr_(expr.as<IfNode>());
} else if (expr.as<OpNode>()) {
res = VisitExpr_(expr.as<OpNode>());
} else if (expr.as<TupleGetItemNode>()) {
res = VisitExpr_(expr.as<TupleGetItemNode>());
} else if (expr.as<RefCreateNode>()) {
res = VisitExpr_(expr.as<RefCreateNode>());
} else if (expr.as<RefReadNode>()) {
res = VisitExpr_(expr.as<RefReadNode>());
} else if (expr.as<RefWriteNode>()) {
res = VisitExpr_(expr.as<RefWriteNode>());
} else if (expr.as<ConstructorNode>()) {
res = VisitExpr_(expr.as<ConstructorNode>());
} else if (expr.as<MatchNode>()) {
res = VisitExpr_(expr.as<MatchNode>());
}
visitor_cache_[expr] = res;
return res;
}

std::vector<GraphNodeRef> VisitExpr_(const VarNode* op) override {
Expr expr = GetRef<Expr>(op);
return var_map_[expr.get()];
Expand Down
5 changes: 0 additions & 5 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,6 @@ class Interpreter :
return VisitExpr(expr);
}

ObjectRef VisitExpr(const Expr& expr) final {
auto ret = ExprFunctor<ObjectRef(const Expr& n)>::VisitExpr(expr);
return ret;
}

ObjectRef VisitExpr_(const VarNode* var_node) final {
return Lookup(GetRef<Var>(var_node));
}
Expand Down
35 changes: 35 additions & 0 deletions src/relay/backend/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <dmlc/json.h>
#include <tvm/driver/driver_api.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/target/codegen.h>
Expand All @@ -42,6 +43,40 @@
namespace tvm {
namespace relay {
namespace backend {

/*!
* \brief A simple wrapper around ExprFunctor for a single argument case.
* The result of visit is memoized.
*/
template <typename OutputType>
class MemoizedExprTranslator : public ::tvm::relay::ExprFunctor<OutputType(const Expr&)> {
using BaseFunctor = ::tvm::relay::ExprFunctor<OutputType(const Expr&)>;

public:
/*! \brief virtual destructor */
virtual ~MemoizedExprTranslator() {}

/*!
* \brief The memoized call.
* \param n The expression node.
* \return The result of the call
*/
virtual OutputType VisitExpr(const Expr& n) {
CHECK(n.defined());
auto it = memo_.find(n);
if (it != memo_.end()) {
return it->second;
}
auto res = BaseFunctor::VisitExpr(n);
memo_[n] = res;
return res;
}

protected:
/*! \brief Internal map used for memoization. */
std::unordered_map<Expr, OutputType, ObjectHash, ObjectEqual> memo_;
};

/*!
* \brief Get the Packed Func
*
Expand Down