diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index ff64d4a3acbb..80919bc30baf 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -28,8 +28,10 @@ #include #include #include +#include #include +#include "../../target/source/codegen_source_base.h" #include "utils.h" namespace tvm { @@ -437,28 +439,51 @@ class RelayBuildModule : public runtime::ModuleNode { ret_.params = graph_codegen_->GetParams(); auto lowered_funcs = graph_codegen_->GetLoweredFunc(); + + // When there is no lowered_funcs due to reasons such as optimization. if (lowered_funcs.size() == 0) { - LOG(WARNING) << "no lowered funcs exist in the compiled module"; + Target target_host = GetTargetHost(); + + // If no target_host has been set, we choose a default one, which is + // llvm if "codegen.LLVMModuleCreate" is accessible. + const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); + if (!target_host.defined()) + target_host = (pf != nullptr) ? target::llvm() : target::stackvm(); + + if (target_host.defined() && target_host->target_name == "llvm") { + // If we can decide the target is LLVM, we then create an empty LLVM module. + ret_.mod = (*pf)(target_host->str(), "empty_module"); + } else { + // If we cannot decide the target is LLVM, we create an empty CSourceModule. + // The code content is initialized with ";" to prevent complaining + // from CSourceModuleNode::SaveToFile. + ret_.mod = tvm::codegen::CSourceModuleCreate(";", ""); + } } else { ret_.mod = tvm::build( lowered_funcs, target_host_, BuildConfig::Current()); } + Array ext_mods = graph_codegen_->GetExternalModules(); - if (!ext_mods.empty()) { - CHECK(lowered_funcs.size() > 0 || ext_mods.size() == 1) - << "Expect to have a TVM DSOModule when multiple external runtime modules exist"; - if (lowered_funcs.size() == 0) { - // Execute the whole module using external runtime. - ret_.mod = ext_mods[0]; - } else { - // Import all external runtime modules. - for (const auto& it : ext_mods) { - ret_.mod.Import(it); + // Import all external runtime modules. + for (const auto& it : ext_mods) + ret_.mod.Import(it); + } + + private: + Target GetTargetHost() { + Target target_host = target_host_; + if (!target_host_.defined()) { + for (const auto &it : targets_) { + if (it.second->device_type == kDLCPU) { + target_host = it.second; + break; } } } + return target_host; } protected: diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 2e04920d866b..a4234d0232da 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -356,6 +356,28 @@ TVM_REGISTER_GLOBAL("codegen.build_llvm") *rv = runtime::Module(n); }); +TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") +.set_body([](TVMArgs args, TVMRetValue *rv) { + auto n = make_object(); + auto target = args[0].operator std::string(); + auto module_name = args[1].operator std::string(); + + // Generate a LLVM module from an input target string + InitializeLLVM(); + auto tm = GetLLVMTargetMachine(target); + auto ctx = std::make_shared(); + std::unique_ptr module(new llvm::Module(module_name, *ctx)); + + // Use a default data layout and target triple + auto triple = tm->getTargetTriple(); + module->setTargetTriple(triple.str()); + module->setDataLayout(tm->createDataLayout()); + + n->Init(std::move(module), ctx); + + *rv = runtime::Module(n); +}); + TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id") .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = static_cast(LookupLLVMIntrinsic(args[0]));