Skip to content

Commit

Permalink
[CodeGen] Generate blob use LLVM directly
Browse files Browse the repository at this point in the history
  • Loading branch information
FrozenGene committed Jan 8, 2020
1 parent bc0274d commit 4f1772c
Show file tree
Hide file tree
Showing 14 changed files with 534 additions and 20 deletions.
3 changes: 3 additions & 0 deletions cmake/util/FindLLVM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,8 @@ macro(find_llvm use_llvm)
message(STATUS "Found LLVM_INCLUDE_DIRS=" ${LLVM_INCLUDE_DIRS})
message(STATUS "Found LLVM_DEFINITIONS=" ${LLVM_DEFINITIONS})
message(STATUS "Found TVM_LLVM_VERSION=" ${TVM_LLVM_VERSION})
if (${TVM_LLVM_VERSION} LESS 40)
message(FATAL_ERROR "TVM requires LLVM 4.0 or higher.")
endif()
endif()
endmacro(find_llvm)
19 changes: 19 additions & 0 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,23 @@ class Target : public ObjectRef {
*/
TVM_DLL static tvm::Target Current(bool allow_not_defined = true);

/*!
* \brief Get the target host based on LLVM.
* \return The target host based on LLVM.
*/
TVM_DLL static tvm::Target GetLLVMTargetHost() {
return llvm_target_host_;
}

/*!
* \brief Set the target host based on LLVM.
* \param llvm_target_host The target host value based on LLVM to be set
*/
TVM_DLL static void SetLLVMTargetHost(const Target& llvm_target_host) {
CHECK(llvm_target_host->str().find("llvm") != std::string::npos);
llvm_target_host_ = llvm_target_host;
}

const TargetNode* operator->() const {
return static_cast<const TargetNode*>(get());
}
Expand All @@ -130,6 +147,8 @@ class Target : public ObjectRef {
* restoring the previous target as the current context.
*/
TVM_DLL void ExitWithScope();

static Target llvm_target_host_;
};

/*! \brief This namespace provides functions to construct Target instances */
Expand Down
15 changes: 15 additions & 0 deletions include/tvm/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,21 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
* \return cstr The C string representation of the file.
*/
std::string PackImportsToC(const runtime::Module& m, bool system_lib);

/*!
* \brief Pack imported device library to a LLVM module.
* Compile the LLVM module and link with the host library
* will allow the DSO loader to automatically discover and import
* the dependency from the shared library.
*
* \param m The host module with the imports.
* \param system_lib Whether expose as system library.
* \param target LLVM target
* \return runtime::Module The generated LLVM module.
*/
runtime::Module PackImportsToLLVM(const runtime::Module& m,
bool system_lib,
const std::string& target);
} // namespace codegen
} // namespace tvm

Expand Down
4 changes: 4 additions & 0 deletions python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,10 @@ def build(inputs,
if not target_host:
target_host = "llvm" if module.enabled("llvm") else "stackvm"

# set the target host based on llvm
if "llvm" in str(target_host):
_target.set_llvm_target_host(target_host)

fhost_all = []
device_modules = []
for tar, flist in target_flist.items():
Expand Down
16 changes: 11 additions & 5 deletions python/tvm/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import struct
from collections import namedtuple

from tvm import target as _target
from ._ffi.function import ModuleBase, _set_class_module
from ._ffi.function import _init_api
from ._ffi.libinfo import find_include_path
Expand Down Expand Up @@ -140,10 +140,16 @@ def export_library(self,
module.get_function("__tvm_is_system_module")())

if self.imported_modules:
path_cc = temp.relpath("devc.cc")
with open(path_cc, "w") as f:
f.write(_PackImportsToC(self, is_system_lib))
files.append(path_cc)
if enabled("llvm"):
path_obj = temp.relpath("devc.o")
m = _PackImportsToLLVM(self, is_system_lib, str(_target.get_llvm_target_host()))
m.save(path_obj)
files.append(path_obj)
else:
path_cc = temp.relpath("devc.cc")
with open(path_cc, "w") as f:
f.write(_PackImportsToC(self, is_system_lib))
files.append(path_cc)
if not fcompile:
if file_name.endswith(".tar"):
fcompile = _tar.tar
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,3 +558,25 @@ def current_target(allow_none=True):
ValueError if current target is not set.
"""
return _api_internal._GetCurrentTarget(allow_none)

def get_llvm_target_host():
"""Returns the target host based on LLVM.
Returns
-------
target : tvm.target.Target
The target object
"""
return _api_internal._GetLLVMTargetHost()

def set_llvm_target_host(target_host):
"""Set the target host based on LLVM.
Parameters
----------
target_host : str or tvm.target.Target
Set the target host
"""
assert isinstance(target_host, (str, Target))
# create(target_host) make sure we pass the Target object to C++ API.
_api_internal._SetLLVMTargetHost(create(target_host))
3 changes: 3 additions & 0 deletions src/api/api_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,8 @@ TVM_REGISTER_GLOBAL("codegen._Build")

TVM_REGISTER_GLOBAL("module._PackImportsToC")
.set_body_typed(PackImportsToC);

TVM_REGISTER_GLOBAL("module._PackImportsToLLVM")
.set_body_typed(PackImportsToLLVM);
} // namespace codegen
} // namespace tvm
17 changes: 16 additions & 1 deletion src/codegen/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ std::string GetDeviceName(const std::string& target_str) {

return "";
}

// Initialize static member of target host
Target Target::llvm_target_host_;
Target Target::Create(const std::string& target_str) {
if (target_str.length() == 0) {
LOG(ERROR) << "target_str must not be empty";
Expand Down Expand Up @@ -575,6 +576,10 @@ runtime::Module build(const Map<Target, Array<LoweredFunc>>& inputs,
target_host_val = DefaultTargetHost(target_host_val);
}

if (target_host_val->str().find("llvm") != std::string::npos) {
Target::SetLLVMTargetHost(target_host_val);
}

for (const auto& it : inputs) {
auto host_dev_funcs =
split_dev_host_funcs(it.second, it.first, target_host_val, config);
Expand Down Expand Up @@ -882,6 +887,16 @@ TVM_REGISTER_GLOBAL("_GetCurrentTarget")
*ret = Target::Current(allow_not_defined);
});

TVM_REGISTER_GLOBAL("_GetLLVMTargetHost")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Target::GetLLVMTargetHost();
});

TVM_REGISTER_GLOBAL("_SetLLVMTargetHost")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Target::SetLLVMTargetHost(args[0]);
});

class Target::Internal {
public:
static void EnterScope(Target target) {
Expand Down
46 changes: 40 additions & 6 deletions src/codegen/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <cstdint>
#include <unordered_set>
#include <cstring>
#include <iomanip>

namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -158,13 +159,21 @@ class ModuleSerializer {
std::vector<uint64_t> import_tree_child_indices_;
};

std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
std::string bin;
dmlc::MemoryStringStream ms(&bin);
dmlc::Stream* stream = &ms;
namespace {
std::string SerializeModule(const runtime::Module& mod) {
std::string bin;
dmlc::MemoryStringStream ms(&bin);
dmlc::Stream* stream = &ms;

ModuleSerializer module_serializer(mod);
module_serializer.SerializeModule(stream);

ModuleSerializer module_serializer(mod);
module_serializer.SerializeModule(stream);
return bin;
}
} // namespace

std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
std::string bin = SerializeModule(mod);

// translate to C program
std::ostringstream os;
Expand Down Expand Up @@ -211,5 +220,30 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
<< "#endif\n";
return os.str();
}

runtime::Module PackImportsToLLVM(const runtime::Module& mod,
bool system_lib,
const std::string& target) {
std::string bin = SerializeModule(mod);

std::ostringstream os;
uint64_t nbytes = bin.length();
os << std::hex;
for (size_t i = 0; i < sizeof(nbytes); ++i) {
os << std::setfill('0') << std::setw(2) << ((nbytes >> (i * 8)) & 0xffUL);
}
for (size_t i = 0; i < bin.length(); ++i) {
int c = bin[i];
os << std::setfill('0') << std::setw(2) << (c & 0xff);
}

// Call codegen_blob to generate LLVM module
std::string codegen_f_name = "codegen.codegen_blob";
// the codegen function.
const PackedFunc* codegen_f = runtime::Registry::Get(codegen_f_name);
CHECK(codegen_f != nullptr) << "codegen.codegen_blob is not presented.";
return (*codegen_f)(os.str(), system_lib, target);
}

} // namespace codegen
} // namespace tvm
Loading

0 comments on commit 4f1772c

Please sign in to comment.