[relay][codegen] VM external codegen #4544

Merged · 1 commit · Dec 20, 2019
Changes from all commits
81 changes: 52 additions & 29 deletions src/relay/backend/vm/compiler.cc
@@ -476,30 +476,39 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
       argument_registers.push_back(reg->second);
     }
 
-    // Next generate the invoke instruction.
     Target target;
-    if (targets_.size() == 1) {
-      // homogeneous execution.
-      for (auto kv : targets_) {
-        target = kv.second;
-      }
+
+    if (!func->UseDefaultCompiler()) {
+      target = tvm::target::ext_dev();
     } else {
-      // heterogeneous execution.
-      LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
+      // Next generate the invoke instruction.
+      if (targets_.size() == 1) {
+        // homogeneous execution.
+        const auto& it = targets_.begin();
+        target = (*it).second;
+      } else {
+        // heterogeneous execution.
+        LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation";
+      }
     }
 
     auto key = CCacheKeyNode::make(func, target);
     auto cfunc = engine_->Lower(key);
 
-    // TODO(jroesch): support lowered funcs for multiple targets
-    CHECK_EQ(cfunc->funcs.size(), 1);
     auto op_index = -1;
-    if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
+    if (!func->UseDefaultCompiler()) {
       op_index = context_->cached_funcs.size();
       context_->cached_funcs.push_back(cfunc);
-      context_->seen_funcs[cfunc->funcs[0]] = op_index;
     } else {
-      op_index = context_->seen_funcs[cfunc->funcs[0]];
+      // TODO(jroesch): support lowered funcs for multiple targets
+      CHECK_EQ(cfunc->funcs.size(), 1);
+      if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
+        op_index = context_->cached_funcs.size();
+        context_->cached_funcs.push_back(cfunc);
+        context_->seen_funcs[cfunc->funcs[0]] = op_index;
+      } else {
+        op_index = context_->seen_funcs[cfunc->funcs[0]];
+      }
     }
 
     Emit(Instruction::InvokePacked(op_index,
@@ -950,32 +959,46 @@ void VMCompiler::LibraryCodegen() {
   if (cached_funcs.size() == 0) {
     return;
   }
-  std::unordered_map<std::string, Array<LoweredFunc>> tgt_funcs;
-  for (auto &cfunc : cached_funcs) {
+  std::unordered_map<std::string, Array<LoweredFunc>> funcs;
+  for (auto& cfunc : cached_funcs) {
     std::string target_str = cfunc->target->str();
-    if (tgt_funcs.count(target_str) == 0) {
-      tgt_funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
+    if (target_str == "ext_dev") {
+      continue;
+    } else if (funcs.count(target_str) == 0) {
+      funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
     } else {
-      tgt_funcs[target_str].push_back(cfunc->funcs[0]);
+      funcs[target_str].push_back(cfunc->funcs[0]);
     }
   }
-  Map<Target, Array<LoweredFunc>> funcs;
-  for (auto &it : tgt_funcs) {
-    funcs.Set(Target::Create(it.first), it.second);
-  }
 
-  if (const auto *f = runtime::Registry::Get("relay.backend.build")) {
-    // The target is just a dummy arg because funcs already contains corresponding target
-    // therefore target won't be used in the build function
-    runtime::Module mod = (*f)(funcs, Target(), target_host_);
+  auto compile_engine = CompileEngine::Global();
+  auto ext_mods = compile_engine->LowerExternalFunctions();
+  runtime::Module mod;
+  if (funcs.size() > 0) {
+    mod = tvm::build(funcs, target_host_, tvm::BuildConfig::Current());
     CHECK(mod.operator->());
-    exec_->lib = mod;
   } else {
-    LOG(FATAL) << "relay.backend.build is not registered";
+    CHECK_EQ(ext_mods.size(), 1U)
+        << "Expect to have a TVM DSOModule when multiple runtime modules exist";
   }
+  if (!ext_mods.empty()) {
+    if (funcs.size() == 0) {
+      mod = ext_mods[0];
+    } else {
+      // Import all external runtime modules.
+      for (auto it : ext_mods) {
+        mod.Import(it);
+      }
+    }
+  }
+  exec_->lib = mod;
   size_t primitive_index = 0;
   for (auto cfunc : cached_funcs) {
-    exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
+    if (cfunc->target->str() == "ext_dev") {
+      exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
+    } else {
+      exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
+    }
   }
 }
 
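For orientation, here is a hedged Python-side sketch of what the compiler.cc changes above produce: functions marked for a non-default compiler are lowered against the ext_dev target, and LibraryCodegen now builds one host module and attaches every external runtime module to it via Import, so a single saved library carries both. The relay.vm.compile and exe.save() calls mirror the check_vm_result helper added in the test file below; the imported_modules inspection at the end is an assumption about the Python Module API, not something exercised in this PR.

# Sketch only: `mod` is assumed to be a Relay module whose annotated functions
# are handled by a registered external codegen (see the test file below).
import tvm
from tvm import relay

with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
    exe = relay.vm.compile(mod, target="llvm")
code, lib = exe.save()
# `lib` is the host module produced by tvm::build(); external runtime modules
# were attached with mod.Import() in LibraryCodegen, so they should appear as
# imports here (attribute name is an assumption, check your TVM version).
print(lib.imported_modules)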
4 changes: 3 additions & 1 deletion src/runtime/vm/vm.cc
@@ -800,7 +800,9 @@ void VirtualMachine::LoadExecutable(const Executable* exec) {
     if (packed_funcs_.size() <= packed_index) {
       packed_funcs_.resize(packed_index + 1);
    }
-    packed_funcs_[packed_index] = lib.GetFunction(packed_name);
+    tvm::runtime::PackedFunc pf = lib.GetFunction(packed_name, true);
+    CHECK(pf != nullptr) << "Cannot find function in module: " << packed_name;
+    packed_funcs_[packed_index] = pf;
   }
 }
 
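The vm.cc change above matters when a saved executable is reloaded: every primitive recorded in the primitive_map must resolve through lib.GetFunction, including functions that live in imported external modules, and a missing symbol now fails with an explicit message. A minimal Python sketch of that reload path, condensed from the check_vm_result helper added in the test file below (code, lib, ctx, and map_inputs are placeholders):

# Sketch: `code` and `lib` come from a previously saved executable, as in
# check_vm_result() below; loading them re-runs LoadExecutable, where the
# per-primitive lookup and CHECK added above take effect.
exe = relay.vm.Executable.load_exec(code, lib)
vm = relay.vm.VirtualMachine(exe)
vm.init(ctx)
out = vm.run(**map_inputs)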
70 changes: 44 additions & 26 deletions tests/python/relay/test_external_codegen.py
@@ -26,36 +26,54 @@
 from tvm import relay
 from tvm.contrib import util
 
-def check_result(mod, map_inputs, out_shape, result, tol=1e-5):
+def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
+                 ctx=tvm.cpu()):
     if sys.platform == "win32":
         print("Skip test on Windows for now")
         return
 
-    with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
-        json, lib, _ = relay.build(mod, "llvm")
-    test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
-    source_dir = os.path.join(test_dir, "..", "..", "..")
-    contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
-
-    kwargs = {}
-    kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path]
-    tmp_path = util.tempdir()
-    lib_name = 'lib.so'
-    lib_path = tmp_path.relpath(lib_name)
-    lib.export_library(lib_path, fcompile=False, **kwargs)
-    lib = tvm.module.load(lib_path)
-
-    ctx = tvm.cpu()
-    rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
-
-    for name, data in map_inputs.items():
-        rt_mod.set_input(name, data)
-
-    rt_mod.run()
-    out = tvm.nd.empty(out_shape, ctx=ctx)
-    out = rt_mod.get_output(0, out)
-
-    tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+    def update_lib(lib):
+        test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
+        source_dir = os.path.join(test_dir, "..", "..", "..")
+        contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
+
+        kwargs = {}
+        kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path]
+        tmp_path = util.tempdir()
+        lib_name = 'lib.so'
+        lib_path = tmp_path.relpath(lib_name)
+        lib.export_library(lib_path, fcompile=False, **kwargs)
+        lib = tvm.module.load(lib_path)
+
+        return lib
+
+    def check_vm_result():
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            exe = relay.vm.compile(mod, target=target)
+        code, lib = exe.save()
+        lib = update_lib(lib)
+        exe = relay.vm.Executable.load_exec(code, lib)
+        vm = relay.vm.VirtualMachine(exe)
+        vm.init(ctx)
+        out = vm.run(**map_inputs)
+        tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+
+    def check_graph_runtime_result():
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            json, lib, _ = relay.build(mod, target=target)
+        lib = update_lib(lib)
+        rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
+
+        for name, data in map_inputs.items():
+            rt_mod.set_input(name, data)
+        rt_mod.run()
+        out = tvm.nd.empty(out_shape, ctx=ctx)
+        out = rt_mod.get_output(0, out)
+
+        tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
+
+    check_vm_result()
+    check_graph_runtime_result()
 
 
 def set_external_func_attr(func, compiler, ext_symbol):
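To round out the picture, here is a hypothetical driver in the style of this test file. The body of set_external_func_attr is collapsed above, so the attribute semantics, the "ccompiler" compiler name, the symbol name, the operator, and the shapes are illustrative assumptions rather than content of this diff.

import numpy as np

def test_extern_add_sketch():
    # Hypothetical: a small function handed off to an external compiler.
    x = relay.var("x", shape=(2, 2))
    y = relay.var("y", shape=(2, 2))
    f = relay.Function([x, y], relay.add(x, y))
    # Assumed behaviour: set_external_func_attr tags f for the external
    # codegen registered under the illustrative name "ccompiler".
    f = set_external_func_attr(f, "ccompiler", "ccompiler_0")

    call = relay.Call(f, [x, y])
    mod = relay.Module.from_expr(call)

    x_data = np.random.rand(2, 2).astype("float32")
    y_data = np.random.rand(2, 2).astype("float32")
    # check_result (defined above) exercises both the VM path and the graph
    # runtime path with the same inputs and expected output.
    check_result(mod, {"x": x_data, "y": y_data}, (2, 2), x_data + y_data)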