Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RUNTIME] Support standardize runtime module #4532

Merged
merged 1 commit into from
Dec 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/tvm/_ffi/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def __init__(self, handle):
def __del__(self):
check_call(_LIB.TVMModFree(self.handle))

def __hash__(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value

@property
def entry_func(self):
"""Get the entry function
Expand Down
64 changes: 41 additions & 23 deletions python/tvm/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,31 +118,28 @@ def export_library(self,
self.save(file_name)
return

if not (self.type_key == "llvm" or self.type_key == "c"):
raise ValueError("Module[%s]: Only llvm and c support export shared" % self.type_key)
modules = self._collect_dso_modules()
temp = _util.tempdir()
if fcompile is not None and hasattr(fcompile, "object_format"):
object_format = fcompile.object_format
else:
if self.type_key == "llvm":
object_format = "o"
files = []
is_system_lib = False
has_c_module = False
for index, module in enumerate(modules):
if fcompile is not None and hasattr(fcompile, "object_format"):
object_format = fcompile.object_format
else:
assert self.type_key == "c"
object_format = "cc"
path_obj = temp.relpath("lib." + object_format)
self.save(path_obj)
files = [path_obj]
is_system_lib = self.type_key == "llvm" and self.get_function("__tvm_is_system_module")()
has_imported_c_file = False
if module.type_key == "llvm":
object_format = "o"
else:
assert module.type_key == "c"
object_format = "cc"
has_c_module = True
path_obj = temp.relpath("lib" + str(index) + "." + object_format)
module.save(path_obj)
files.append(path_obj)
is_system_lib = (module.type_key == "llvm" and
module.get_function("__tvm_is_system_module")())

if self.imported_modules:
for i, m in enumerate(self.imported_modules):
if m.type_key == "c":
has_imported_c_file = True
c_file_name = "tmp_" + str(i) + ".cc"
path_cc = temp.relpath(c_file_name)
with open(path_cc, "w") as f:
f.write(m.get_source())
files.append(path_cc)
path_cc = temp.relpath("devc.cc")
with open(path_cc, "w") as f:
f.write(_PackImportsToC(self, is_system_lib))
Expand All @@ -152,13 +149,15 @@ def export_library(self,
fcompile = _tar.tar
else:
fcompile = _cc.create_shared
if self.type_key == "c" or has_imported_c_file:

if has_c_module:
options = []
if "options" in kwargs:
opts = kwargs["options"]
options = opts if isinstance(opts, (list, tuple)) else [opts]
opts = options + ["-I" + path for path in find_include_path()]
kwargs.update({'options': opts})

fcompile(file_name, files, **kwargs)

def time_evaluator(self, func_name, ctx, number=10, repeat=1, min_repeat_ms=0):
Expand Down Expand Up @@ -219,6 +218,25 @@ def evaluator(*args):
except NameError:
raise NameError("time_evaluate is only supported when RPC is enabled")

def _collect_dso_modules(self):
"""Helper function to collect dso modules, then return it."""
visited, stack, dso_modules = set(), [], []
# append root module
visited.add(self)
stack.append(self)
while stack:
module = stack.pop()
if module._dso_exportable():
dso_modules.append(module)
for m in module.imported_modules:
if m not in visited:
visited.add(m)
stack.append(m)
return dso_modules

def _dso_exportable(self):
return self.type_key == "llvm" or self.type_key == "c"


def system_lib():
"""Get system-wide library module singleton.
Expand Down
116 changes: 105 additions & 11 deletions src/codegen/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
#include <tvm/build_module.h>
#include <dmlc/memory_io.h>
#include <sstream>
#include <iostream>
#include <vector>
#include <cstdint>
#include <unordered_set>
#include <cstring>

namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -58,20 +61,111 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
return m;
}

/*! \brief Helper class to serialize module */
class ModuleSerializer {
public:
explicit ModuleSerializer(runtime::Module mod) : mod_(mod) {
Init();
}

void SerializeModule(dmlc::Stream* stream) {
// Only have one DSO module and it is in the root, then
// we will not produce import_tree_.
bool has_import_tree = true;
if (DSOExportable(mod_.operator->()) && mod_->imports().empty()) {
has_import_tree = false;
}
uint64_t sz = 0;
if (has_import_tree) {
// we will append one key for _import_tree
// The layout is the same as before: binary_size, key, logic, key, logic...
sz = mod_vec_.size() + 1;
} else {
// Keep the old behaviour
sz = mod_->imports().size();
}
stream->Write(sz);

for (auto m : mod_vec_) {
std::string mod_type_key = m->type_key();
if (!DSOExportable(m)) {
stream->Write(mod_type_key);
m->SaveToBinary(stream);
} else if (has_import_tree) {
mod_type_key = "_lib";
stream->Write(mod_type_key);
}
}

// Write _import_tree key if we have
if (has_import_tree) {
std::string import_key = "_import_tree";
stream->Write(import_key);
stream->Write(import_tree_row_ptr_);
stream->Write(import_tree_child_indices_);
}
}

private:
void Init() {
CreateModuleIndex();
CreateImportTree();
}

// invariance: root module is always at location 0.
// The module order is collected via DFS
void CreateModuleIndex() {
std::unordered_set<const runtime::ModuleNode*> visited {mod_.operator->()};
std::vector<runtime::ModuleNode*> stack {mod_.operator->()};
uint64_t module_index = 0;

while (!stack.empty()) {
runtime::ModuleNode* n = stack.back();
stack.pop_back();
mod2index_[n] = module_index++;
mod_vec_.emplace_back(n);
for (runtime::Module m : n->imports()) {
runtime::ModuleNode* next = m.operator->();
if (visited.count(next) == 0) {
visited.insert(next);
stack.push_back(next);
}
}
}
}

void CreateImportTree() {
for (auto m : mod_vec_) {
for (runtime::Module im : m->imports()) {
uint64_t mod_index = mod2index_[im.operator->()];
import_tree_child_indices_.push_back(mod_index);
}
import_tree_row_ptr_.push_back(import_tree_child_indices_.size());
}
}

bool DSOExportable(const runtime::ModuleNode* mod) {
return !std::strcmp(mod->type_key(), "llvm") ||
!std::strcmp(mod->type_key(), "c");
}

runtime::Module mod_;
// construct module to index
std::unordered_map<runtime::ModuleNode*, size_t> mod2index_;
// index -> module
std::vector<runtime::ModuleNode*> mod_vec_;
std::vector<uint64_t> import_tree_row_ptr_ {0};
std::vector<uint64_t> import_tree_child_indices_;
};

std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
std::string bin;
dmlc::MemoryStringStream ms(&bin);
dmlc::Stream* stream = &ms;
uint64_t sz = static_cast<uint64_t>(mod->imports().size());
stream->Write(sz);
for (runtime::Module im : mod->imports()) {
CHECK_EQ(im->imports().size(), 0U)
<< "Only support simply one-level hierarchy";
std::string tkey = im->type_key();
stream->Write(tkey);
if (tkey == "c") continue;
im->SaveToBinary(stream);
}

ModuleSerializer module_serializer(mod);
module_serializer.SerializeModule(stream);

// translate to C program
std::ostringstream os;
os << "#ifdef _WIN32\n"
Expand Down
66 changes: 54 additions & 12 deletions src/runtime/library_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/runtime/registry.h>
#include <string>
#include <vector>
#include <cstdint>
#include "library_module.h"

namespace tvm {
Expand Down Expand Up @@ -108,9 +109,11 @@ void InitContextFunctions(std::function<void*(const char*)> fgetsymbol) {
/*!
* \brief Load and append module blob to module list
* \param mblob The module blob.
* \param module_list The module list to append to
* \param lib The library.
*
* \return Root Module.
*/
void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr<Library> lib) {
#ifndef _LIBCPP_SGX_CONFIG
CHECK(mblob != nullptr);
uint64_t nbytes = 0;
Expand All @@ -123,20 +126,56 @@ void ImportModuleBlob(const char* mblob, std::vector<Module>* mlist) {
dmlc::Stream* stream = &fs;
uint64_t size;
CHECK(stream->Read(&size));
std::vector<Module> modules;
std::vector<uint64_t> import_tree_row_ptr;
std::vector<uint64_t> import_tree_child_indices;
for (uint64_t i = 0; i < size; ++i) {
std::string tkey;
CHECK(stream->Read(&tkey));
if (tkey == "c") continue;
std::string fkey = "module.loadbinary_" + tkey;
const PackedFunc* f = Registry::Get(fkey);
CHECK(f != nullptr)
// Currently, _lib is for DSOModule, but we
// don't have loadbinary function for it currently
if (tkey == "_lib") {
auto dso_module = Module(make_object<LibraryModuleNode>(lib));
modules.emplace_back(dso_module);
} else if (tkey == "_import_tree") {
CHECK(stream->Read(&import_tree_row_ptr));
CHECK(stream->Read(&import_tree_child_indices));
} else {
std::string fkey = "module.loadbinary_" + tkey;
const PackedFunc* f = Registry::Get(fkey);
CHECK(f != nullptr)
<< "Loader of " << tkey << "("
<< fkey << ") is not presented.";
Module m = (*f)(static_cast<void*>(stream));
mlist->push_back(m);
Module m = (*f)(static_cast<void*>(stream));
modules.emplace_back(m);
}
}
// if we are using old dll, we don't have import tree
// so that we can't reconstruct module relationship using import tree
if (import_tree_row_ptr.empty()) {
auto n = make_object<LibraryModuleNode>(lib);
auto module_import_addr = ModuleInternal::GetImportsAddr(n.operator->());
for (const auto& m : modules) {
module_import_addr->emplace_back(m);
}
return Module(n);
} else {
for (size_t i = 0; i < modules.size(); ++i) {
for (size_t j = import_tree_row_ptr[i]; j < import_tree_row_ptr[i + 1]; ++j) {
auto module_import_addr = ModuleInternal::GetImportsAddr(modules[i].operator->());
auto child_index = import_tree_child_indices[j];
CHECK(child_index < modules.size());
module_import_addr->emplace_back(modules[child_index]);
}
}
}
CHECK(!modules.empty());
// invariance: root module is always at location 0.
// The module order is collected via DFS
return modules[0];
FrozenGene marked this conversation as resolved.
Show resolved Hide resolved
#else
LOG(FATAL) << "SGX does not support ImportModuleBlob";
return Module();
#endif
}

Expand All @@ -149,17 +188,20 @@ Module CreateModuleFromLibrary(ObjectPtr<Library> lib) {
const char* dev_mblob =
reinterpret_cast<const char*>(
lib->GetSymbol(runtime::symbol::tvm_dev_mblob));
Module root_mod;
if (dev_mblob != nullptr) {
ImportModuleBlob(
dev_mblob, ModuleInternal::GetImportsAddr(n.operator->()));
root_mod = ProcessModuleBlob(dev_mblob, lib);
} else {
// Only have one single DSO Module
root_mod = Module(n);
}

Module root_mod = Module(n);
// allow lookup of symbol from root(so all symbols are visible).
// allow lookup of symbol from root (so all symbols are visible).
if (auto *ctx_addr =
reinterpret_cast<void**>(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = root_mod.operator->();
}

return root_mod;
}
} // namespace runtime
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
if (it != import_cache_.end()) return it->second.get();
PackedFunc pf;
for (Module& m : this->imports_) {
pf = m.GetFunction(name, false);
pf = m.GetFunction(name, true);
if (pf != nullptr) break;
}
if (pf == nullptr) {
Expand Down
Loading