[RELAY] Switch from CompileEngine to TECompiler in Interpreter #8486

Merged
merged 1 commit into from
Jul 17, 2021
2 changes: 1 addition & 1 deletion include/tvm/relay/function.h
@@ -126,7 +126,7 @@ namespace attr {
/*! \brief Mark the function as a primitive function. */
constexpr const char* kPrimitive = "Primitive";
/*!
* \brief Indicate the compiler that should be used for builing this function.
* \brief Indicate the compiler that should be used for building this function.
* When this is unset or set to "default", the default compilation pipeline will be used.
*/
constexpr const char* kCompiler = "Compiler";
17 changes: 8 additions & 9 deletions include/tvm/runtime/device_api.h
@@ -291,7 +291,14 @@ inline Device RemoveRPCSessionMask(Device dev) {
return dev;
}

inline std::ostream& operator<<(std::ostream& os, DLDevice dev);
inline std::ostream& operator<<(std::ostream& os, DLDevice dev) { // NOLINT(*)
if (tvm::runtime::IsRPCSessionDevice(dev)) {
os << "remote[" << tvm::runtime::GetRPCSessionIndex(dev) << "]-";
dev = tvm::runtime::RemoveRPCSessionMask(dev);
}
os << tvm::runtime::DeviceName(static_cast<int>(dev.device_type)) << "(" << dev.device_id << ")";
return os;
}

/*!
* \brief Add a RPC session mask to a Device.
@@ -308,14 +315,6 @@ inline Device AddRPCSessionMask(Device dev, int session_table_index) {
return dev;
}

inline std::ostream& operator<<(std::ostream& os, DLDevice dev) { // NOLINT(*)
if (IsRPCSessionDevice(dev)) {
os << "remote[" << GetRPCSessionIndex(dev) << "]-";
dev = RemoveRPCSessionMask(dev);
}
os << runtime::DeviceName(static_cast<int>(dev.device_type)) << "(" << dev.device_id << ")";
return os;
}
} // namespace runtime
} // namespace tvm
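
Note on the hunk above: the streaming operator prints an optional `remote[<session index>]-` prefix followed by `<device name>(<device id>)`. The standalone sketch below reproduces only that formatting logic. Everything in the `sketch` namespace (the `Device` struct, the mask value `kSessionMask`, the two-entry name table, and `ToString`) is a hypothetical stand-in chosen for illustration; it is not TVM's actual encoding or API.

```cpp
#include <ostream>
#include <sstream>
#include <string>

namespace sketch {

// Assumed encoding for this sketch only:
// device_type = (session_index + 1) * kSessionMask + base_type.
constexpr int kSessionMask = 128;

// Hypothetical stand-in for DLDevice.
struct Device {
  int device_type;
  int device_id;
};

inline bool IsRPCSessionDevice(Device dev) { return dev.device_type / kSessionMask > 0; }

inline int GetRPCSessionIndex(Device dev) { return dev.device_type / kSessionMask - 1; }

inline Device RemoveRPCSessionMask(Device dev) {
  dev.device_type %= kSessionMask;
  return dev;
}

inline Device AddRPCSessionMask(Device dev, int session_index) {
  dev.device_type += (session_index + 1) * kSessionMask;
  return dev;
}

// Toy name table covering two device types; anything else is "unknown".
inline const char* DeviceName(int type) {
  switch (type) {
    case 1:
      return "cpu";
    case 2:
      return "cuda";
    default:
      return "unknown";
  }
}

// Same shape as the operator in the hunk: optional remote prefix, then name(id).
inline std::ostream& operator<<(std::ostream& os, Device dev) {
  if (IsRPCSessionDevice(dev)) {
    os << "remote[" << GetRPCSessionIndex(dev) << "]-";
    dev = RemoveRPCSessionMask(dev);
  }
  os << DeviceName(dev.device_type) << "(" << dev.device_id << ")";
  return os;
}

// Convenience helper so the formatted text can be inspected as a string.
inline std::string ToString(Device dev) {
  std::ostringstream os;
  os << dev;
  return os.str();
}

}  // namespace sketch
```

Under these assumptions, `sketch::ToString({2, 0})` yields `cuda(0)`, and a device wrapped with `sketch::AddRPCSessionMask({1, 3}, 0)` yields `remote[0]-cpu(3)`.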

6 changes: 3 additions & 3 deletions python/tvm/relay/backend/graph_executor_codegen.py
@@ -20,14 +20,14 @@
The compiler is built from a few pieces.
First we define a compiler from a single Relay expression to the
graph langauge. We require the expression to be a function.
graph language. We require the expression to be a function.
The function's parameters correspond to the placeholder/inputs
and model parameters found in the computation graph representation.
The body of the function represents the computation graph.
The compiler's output is a program in the graph language, which is composed of
graph langauge is composed of Node, NodeRef, InputNode, OpNode.
This "little language" represents programs in TVM's graph format.
Node, NodeRef, InputNode, OpNode. This "little language" represents programs in
TVM's graph format.
To connect to the graph executor, we use a printer that converts our graph format
into TVM's JSON format. The resulting string can be loaded by
1 change: 0 additions & 1 deletion src/parser/source_map.cc
@@ -40,7 +40,6 @@ Source::Source(SourceName src_name, std::string source) {
// NB(@jroesch):
std::string source_str = n->source;
for (auto c : source_str) {
DLOG(INFO) << "char=" << c;
if (c == '\n') {
// Record the length of the line.
n->line_map.back().second = length;
27 changes: 13 additions & 14 deletions src/relay/backend/interpreter.cc
@@ -34,6 +34,7 @@

#include "../transforms/pass_utils.h"
#include "compile_engine.h"
#include "te_compiler.h"

namespace tvm {
namespace relay {
@@ -214,9 +215,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
public:
Interpreter(IRModule mod, Device device, Target target)
: mod_(mod), device_(device), target_(target), debug_op_(Op::Get("debug")) {
engine_ = CompileEngine::Global();
}
: mod_(mod), device_(device), target_(target), debug_op_(Op::Get("debug")) {}

template <typename T>
T WithFrame(const Frame& fr, const std::function<T()>& f) {
@@ -286,7 +285,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,

Array<Shape> ComputeDynamicShape(const Function& func, const Array<ObjectRef>& args) {
CCacheKey key(func, Target("llvm"));
auto cfunc = engine_->LowerShapeFunc(key);
auto cfunc = compiler_->LowerShapeFunc(key);
size_t arity = cfunc->inputs.size() + cfunc->outputs.size();

std::vector<TVMValue> values(arity);
@@ -485,7 +484,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
out_shapes = ComputeDynamicShape(func, args);
}

PackedFunc packed_func = engine_->JIT(CCacheKey(func, target_));
PackedFunc packed_func = compiler_->JIT(CCacheKey(func, target_));
TVMRetValue rv;
if (const TupleTypeNode* rtype = func->body->checked_type().as<TupleTypeNode>()) {
ICHECK(!is_dyn || out_shapes.size() == rtype->fields.size());
@@ -555,11 +554,11 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
// We should not find operators after running fusion,
// and operator lowering.
//
// We have some functions cotaining chunks of operators
// We have some functions containing chunks of operators
// which will be loaded into operator map.
if (const auto* op_node = call->op.as<OpNode>()) {
LOG(FATAL) << "found " << op_node->name
<< "; operators should be removed by future passes; try "
<< "; operators should have been removed by previous passes; try "
"fusing and lowering";
}
if (auto con = call->op.as<ConstructorNode>()) {
@@ -569,9 +568,9 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
ObjectRef fn_val = Eval(call->op);
if (const InterpreterClosureObj* closure_node = fn_val.as<InterpreterClosureObj>()) {
auto closure = GetRef<InterpreterClosure>(closure_node);
return this->Invoke(closure, args);
return Invoke(closure, args);
} else if (const RecClosureObj* closure_node = fn_val.as<RecClosureObj>()) {
return this->Invoke(closure_node->clos, args, closure_node->bind);
return Invoke(closure_node->clos, args, closure_node->bind);
} else {
LOG(FATAL) << "internal error: type error, expected function value in the call "
<< "position";
@@ -710,17 +709,17 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
Target target_;
// Object stack.
Stack stack_;
// Backend compile engine.
CompileEngine engine_;
// TE-to-TIR lowerer (compiler).
TECompiler compiler_;

@xqdan (Contributor), Jul 16, 2021

It's not related to this PR, but is it possible to hold a customized TECompiler? Say, an NPU TECompiler, which has a different pass flow.

Contributor Author

Ultimately we want to enable customized rewrites/transforms at all stages. This work just gets us out of the go-straight-to-packed-func world, so we have a representation that makes the current TE visible.

This is my first week so I'm just wrapping my head around this -- I defer to Jarad for the True Story.

Member

@xqdan the goal of TECompiler is to have one lowering which applies initial scheduling to functions and puts them in the IRModule. After this, users can customize the flow for specific targets or executors.
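
To make the flow described in this thread concrete, here is a minimal, self-contained sketch: a single lowering step that caches its results by (function, target) and collects them into a module, followed by target-specific passes applied afterwards. All names here (`ToyCompiler`, `LoweredFunc`, `Module`, `Pass`, `Customize`) are hypothetical and exist only for illustration; this is not the `TECompiler` interface from `te_compiler.h`.

```cpp
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

namespace sketch {

// Hypothetical stand-in for a lowered, scheduled function.
struct LoweredFunc {
  std::string name;
  std::string schedule;
};

// Hypothetical stand-in for a module: lowered functions keyed by name.
using Module = std::map<std::string, LoweredFunc>;
// A customization pass mutates the module in place.
using Pass = std::function<void(Module*)>;

class ToyCompiler {
 public:
  // Lower one function for one target, reusing a cached result when available.
  const LoweredFunc& Lower(const std::string& func, const std::string& target) {
    auto key = std::make_pair(func, target);
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      ++num_lowerings_;
      it = cache_.emplace(key, LoweredFunc{func + "_" + target, "default"}).first;
    }
    return it->second;
  }

  // Lower every function with an initial schedule and collect the results in a module.
  Module LowerAll(const std::vector<std::string>& funcs, const std::string& target) {
    Module mod;
    for (const auto& f : funcs) {
      const LoweredFunc& lowered = Lower(f, target);
      mod[lowered.name] = lowered;
    }
    return mod;
  }

  // Number of cache misses, i.e. how many functions were actually lowered.
  int num_lowerings() const { return num_lowerings_; }

 private:
  std::map<std::pair<std::string, std::string>, LoweredFunc> cache_;
  int num_lowerings_ = 0;
};

// Apply target-specific passes after the shared lowering step.
inline Module Customize(Module mod, const std::vector<Pass>& passes) {
  for (const auto& pass : passes) pass(&mod);
  return mod;
}

}  // namespace sketch
```

In this sketch an NPU-specific flow would not need its own compiler subclass: it would call `LowerAll` once and then hand `Customize` a list of NPU passes that rewrite the schedules in the resulting module.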

// Cache ops that need to be frequently used later to reduce lookup overhead.
const Op& debug_op_;
};

TypedPackedFunc<ObjectRef(Expr)> CreateInterpreter(IRModule mod, Device device, Target target) {
if (mod.defined()) {
// eta expand to support constructors in argument position
transform::Sequential seq({transform::EtaExpand(
/* expand_constructor */ true, /* expand_global_var */ false),
transform::Sequential seq({// eta expand to support constructors in argument position
transform::EtaExpand(
/*expand_constructor=*/true, /*expand_global_var=*/false),
transform::InferType()});

transform::PassContext pass_ctx = transform::PassContext::Current();
2 changes: 1 addition & 1 deletion src/relay/backend/te_compiler.cc
@@ -444,7 +444,7 @@ class LowerTensorExpr : public ExprMutator {
tir_call_attrs->metadata.Set("relay_attrs", func->attrs);

Expr ret_call = Call(lowered_func->prim_fn_var, args, Attrs(tir_call_attrs));
return ret_call;
return std::move(ret_call);
}

IRModule module_;
11 changes: 6 additions & 5 deletions src/relay/backend/te_compiler.h
@@ -76,7 +76,7 @@ using ProcessFn = std::function<void(Function)>;

/*!
* \brief A compiler which lowers primitive Relay functions to tensor expressions
* and schdules them into TIR functions.
* and schedules them into TIR functions.
*/
class TECompilerNode : public Object {
public:
@@ -178,10 +178,11 @@ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets);
* This is the "back half" of the Relay compiler which lowers "primitive functions"
* to TE expressions, schedules them, and then to TIR.
*
* /param module The IRModule.
* /param targets The mapping for devices to targets.
* /param device_map An analysis result mapping each sub-expression to a device.
* /return The lowered module, see above.
* \param compiler The TE-to-TIR compiler (which caches lowered functions).
* \param module The IRModule.
* \param targets The mapping for devices to targets.
* \param device_map An analysis result mapping each sub-expression to a device.
* \return The lowered module, see above.
*/
// TODO(@electriclilies): Not sure if this default initialization is correct...
LoweredModule LowerTE(