Switch from CompileEngine to TECompiler in Interpreter
This continues on:
https://discuss.tvm.apache.org/t/rfc-relay-tecompiler-rewrite-existing-compile-engine-to-match-updated-compiler-flow/9233
and #751, this time just replacing CompileEngine with TECompiler in the Interpreter,
using the JIT helper added to help the transition.

Some whitespace improvements while there.
mbs-octoml committed Jul 16, 2021
1 parent cba9cf3 commit fd9e188
Showing 7 changed files with 32 additions and 34 deletions.
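The core of the change is mechanical: the interpreter's cached CompileEngine becomes a TECompiler member, and lowering goes through the same JIT-style entry point. A minimal sketch of the call-site swap (identifiers taken from the diff below; surrounding code elided):

    // Before: JIT-compile a primitive function via the global CompileEngine.
    CompileEngine engine_ = CompileEngine::Global();
    PackedFunc packed_func = engine_->JIT(CCacheKey(func, target_));

    // After: a TECompiler owned by the interpreter plays the same role,
    // going through the transitional JIT helper mentioned above.
    TECompiler compiler_;
    PackedFunc packed_func = compiler_->JIT(CCacheKey(func, target_));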
2 changes: 1 addition & 1 deletion include/tvm/relay/function.h
@@ -126,7 +126,7 @@ namespace attr {
 /*! \brief Mark the function as a primitive function. */
 constexpr const char* kPrimitive = "Primitive";
 /*!
- * \brief Indicate the compiler that should be used for builing this function.
+ * \brief Indicate the compiler that should be used for building this function.
  * When this is unset or set to "default", the default compilation pipeline will be used.
  */
 constexpr const char* kCompiler = "Compiler";
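For context, kCompiler is a function-level attribute consulted during lowering; a BYOC-style flow would tag a Relay function roughly like this (a sketch — WithAttr is the usual attribute helper, and the backend name "my_backend" is made up):

    // Route this function to an external codegen instead of the default pipeline.
    func = WithAttr(std::move(func), attr::kCompiler, runtime::String("my_backend"));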
17 changes: 8 additions & 9 deletions include/tvm/runtime/device_api.h
@@ -291,7 +291,14 @@ inline Device RemoveRPCSessionMask(Device dev) {
   return dev;
 }
 
-inline std::ostream& operator<<(std::ostream& os, DLDevice dev);
+inline std::ostream& operator<<(std::ostream& os, DLDevice dev) {  // NOLINT(*)
+  if (tvm::runtime::IsRPCSessionDevice(dev)) {
+    os << "remote[" << tvm::runtime::GetRPCSessionIndex(dev) << "]-";
+    dev = tvm::runtime::RemoveRPCSessionMask(dev);
+  }
+  os << tvm::runtime::DeviceName(static_cast<int>(dev.device_type)) << "(" << dev.device_id << ")";
+  return os;
+}
 
 /*!
  * \brief Add a RPC session mask to a Device.
@@ -308,14 +315,6 @@ inline Device AddRPCSessionMask(Device dev, int session_table_index) {
   return dev;
 }
 
-inline std::ostream& operator<<(std::ostream& os, DLDevice dev) {  // NOLINT(*)
-  if (IsRPCSessionDevice(dev)) {
-    os << "remote[" << GetRPCSessionIndex(dev) << "]-";
-    dev = RemoveRPCSessionMask(dev);
-  }
-  os << runtime::DeviceName(static_cast<int>(dev.device_type)) << "(" << dev.device_id << ")";
-  return os;
-}
 }  // namespace runtime
 }  // namespace tvm

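With the definition now above AddRPCSessionMask, streaming a Device behaves as before; a small usage sketch (assuming the usual device-name table, where kDLCPU renders as "cpu"):

    Device local{kDLCPU, 0};
    std::ostringstream os;
    os << local;                        // "cpu(0)"
    os << AddRPCSessionMask(local, 3);  // "remote[3]-cpu(0)"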
6 changes: 3 additions & 3 deletions python/tvm/relay/backend/graph_executor_codegen.py
@@ -20,14 +20,14 @@
 The compiler is built from a few pieces.
 
 First we define a compiler from a single Relay expression to the
-graph langauge. We require the expression to be a function.
+graph language. We require the expression to be a function.
 The function's parameters correspond to the placeholder/inputs
 and model parameters found in the computation graph representation.
 The body of the function represents the computation graph.
 
 The compiler's output is a program in the graph language, which is composed of
-graph langauge is composed of Node, NodeRef, InputNode, OpNode.
-This "little language" represents programs in TVM's graph format.
+Node, NodeRef, InputNode, OpNode. This "little language" represents programs in
+TVM's graph format.
 
 To connect to the graph executor, we use a printer that converts our graph format
 into TVM's JSON format. The resulting string can be loaded by
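For reference, the JSON the printer emits is a flat, indexed encoding of those nodes. Its top level looks roughly like this (key names follow TVM's graph JSON format; the values shown are illustrative only):

    {
      "nodes": [...],          // serialized input and op nodes
      "arg_nodes": [0],        // indices of placeholder/parameter nodes
      "heads": [[1, 0, 0]],    // outputs as [node index, output index, version]
      "attrs": {...}           // per-node shape/dtype/storage tables
    }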
1 change: 0 additions & 1 deletion src/parser/source_map.cc
@@ -40,7 +40,6 @@ Source::Source(SourceName src_name, std::string source) {
   // NB(@jroesch):
   std::string source_str = n->source;
   for (auto c : source_str) {
-    DLOG(INFO) << "char=" << c;
     if (c == '\n') {
       // Record the length of the line.
       n->line_map.back().second = length;
27 changes: 13 additions & 14 deletions src/relay/backend/interpreter.cc
@@ -34,6 +34,7 @@
 
 #include "../transforms/pass_utils.h"
 #include "compile_engine.h"
+#include "te_compiler.h"
 
 namespace tvm {
 namespace relay {
@@ -214,9 +215,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
                     PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
  public:
   Interpreter(IRModule mod, Device device, Target target)
-      : mod_(mod), device_(device), target_(target), debug_op_(Op::Get("debug")) {
-    engine_ = CompileEngine::Global();
-  }
+      : mod_(mod), device_(device), target_(target), debug_op_(Op::Get("debug")) {}
 
   template <typename T>
   T WithFrame(const Frame& fr, const std::function<T()>& f) {
@@ -286,7 +285,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
 
   Array<Shape> ComputeDynamicShape(const Function& func, const Array<ObjectRef>& args) {
     CCacheKey key(func, Target("llvm"));
-    auto cfunc = engine_->LowerShapeFunc(key);
+    auto cfunc = compiler_->LowerShapeFunc(key);
     size_t arity = cfunc->inputs.size() + cfunc->outputs.size();
 
     std::vector<TVMValue> values(arity);
@@ -485,7 +484,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
       out_shapes = ComputeDynamicShape(func, args);
     }
 
-    PackedFunc packed_func = engine_->JIT(CCacheKey(func, target_));
+    PackedFunc packed_func = compiler_->JIT(CCacheKey(func, target_));
     TVMRetValue rv;
     if (const TupleTypeNode* rtype = func->body->checked_type().as<TupleTypeNode>()) {
       ICHECK(!is_dyn || out_shapes.size() == rtype->fields.size());
@@ -555,11 +554,11 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
     // We should not find operators after running fusion,
     // and operator lowering.
     //
-    // We have some functions cotaining chunks of operators
+    // We have some functions containing chunks of operators
     // which will be loaded into operator map.
     if (const auto* op_node = call->op.as<OpNode>()) {
       LOG(FATAL) << "found " << op_node->name
-                 << "; operators should be removed by future passes; try "
+                 << "; operators should have been removed by previous passes; try "
                     "fusing and lowering";
     }
     if (auto con = call->op.as<ConstructorNode>()) {
@@ -569,9 +568,9 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
     ObjectRef fn_val = Eval(call->op);
     if (const InterpreterClosureObj* closure_node = fn_val.as<InterpreterClosureObj>()) {
       auto closure = GetRef<InterpreterClosure>(closure_node);
-      return this->Invoke(closure, args);
+      return Invoke(closure, args);
     } else if (const RecClosureObj* closure_node = fn_val.as<RecClosureObj>()) {
-      return this->Invoke(closure_node->clos, args, closure_node->bind);
+      return Invoke(closure_node->clos, args, closure_node->bind);
     } else {
       LOG(FATAL) << "internal error: type error, expected function value in the call "
                  << "position";
@@ -710,17 +709,17 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
   Target target_;
   // Object stack.
   Stack stack_;
-  // Backend compile engine.
-  CompileEngine engine_;
+  // TE-to-TIR lowerer (compiler).
+  TECompiler compiler_;
   // Cache ops that need to be frequently used later to reduce lookup overhead.
   const Op& debug_op_;
 };
 
 TypedPackedFunc<ObjectRef(Expr)> CreateInterpreter(IRModule mod, Device device, Target target) {
   if (mod.defined()) {
-    // eta expand to support constructors in argument position
-    transform::Sequential seq({transform::EtaExpand(
-        /* expand_constructor */ true, /* expand_global_var */ false),
+    transform::Sequential seq({// eta expand to support constructors in argument position
+                               transform::EtaExpand(
+                                   /*expand_constructor=*/true, /*expand_global_var=*/false),
                                transform::InferType()});
 
     transform::PassContext pass_ctx = transform::PassContext::Current();
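For orientation, the entry point at the bottom of this file is used roughly as follows (a sketch; assumes mod and expr are a type-checked IRModule and Relay expression):

    // Build an evaluator bound to one (device, target) pair, then run an expression.
    Device dev{kDLCPU, 0};
    TypedPackedFunc<ObjectRef(Expr)> eval = CreateInterpreter(mod, dev, Target("llvm"));
    ObjectRef result = eval(expr);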
2 changes: 1 addition & 1 deletion src/relay/backend/te_compiler.cc
@@ -444,7 +444,7 @@ class LowerTensorExpr : public ExprMutator {
     tir_call_attrs->metadata.Set("relay_attrs", func->attrs);
 
     Expr ret_call = Call(lowered_func->prim_fn_var, args, Attrs(tir_call_attrs));
-    return ret_call;
+    return std::move(ret_call);
   }
 
   IRModule module_;
11 changes: 6 additions & 5 deletions src/relay/backend/te_compiler.h
@@ -76,7 +76,7 @@ using ProcessFn = std::function<void(Function)>;
 
 /*!
  * \brief A compiler which lowers primitive Relay functions to tensor expressions
- * and schdules them into TIR functions.
+ * and schedules them into TIR functions.
  */
 class TECompilerNode : public Object {
  public:
@@ -178,10 +178,11 @@ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets);
  * This is the "back half" of the Relay compiler which lowers "primitive functions"
  * to TE expressions, schedules them, and then to TIR.
  *
- * /param module The IRModule.
- * /param targets The mapping for devices to targets.
- * /param device_map An analysis result mapping each sub-expression to a device.
- * /return The lowered module, see above.
+ * \param compiler The TE-to-TIR compiler (which caches lowered functions)
+ * \param module The IRModule.
+ * \param targets The mapping for devices to targets.
+ * \param device_map An analysis result mapping each sub-expression to a device.
+ * \return The lowered module, see above.
  */
 // TODO(@electriclilies): Not sure if this default initialization is correct...
 LoweredModule LowerTE(
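Taken together with the interpreter diff above, the TECompiler surface exercised here is cache-shaped: a CCacheKey pairs a primitive function with a target, and lowering is memoized per key. A sketch under those assumptions (only JIT and LowerShapeFunc appear in this commit; prim_func is a placeholder):

    TECompiler compiler;
    CCacheKey key(prim_func, Target("llvm"));
    PackedFunc runnable = compiler->JIT(key);         // lower, compile, and load
    auto shape_func = compiler->LowerShapeFunc(key);  // shape function for dynamic shapes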
