From f354a002887dc5dbc058bc5ad9d26e3d16526e1e Mon Sep 17 00:00:00 2001 From: FrozenGene Date: Wed, 8 Jan 2020 19:32:52 +0800 Subject: [PATCH] [CodeGen] Generate blob use LLVM directly --- cmake/util/FindLLVM.cmake | 3 + include/tvm/codegen.h | 15 ++ python/tvm/contrib/cc.py | 40 ++++- python/tvm/contrib/ndk.py | 4 + python/tvm/module.py | 24 ++- src/api/api_codegen.cc | 3 + src/codegen/codegen.cc | 45 +++++- src/codegen/llvm/codegen_blob.cc | 162 +++++++++++++++++++++ src/codegen/llvm/codegen_blob.h | 51 +++++++ src/codegen/llvm/llvm_module.cc | 42 +++++- tests/python/unittest/test_codegen_blob.py | 104 +++++++++++++ 11 files changed, 472 insertions(+), 21 deletions(-) create mode 100644 src/codegen/llvm/codegen_blob.cc create mode 100644 src/codegen/llvm/codegen_blob.h create mode 100644 tests/python/unittest/test_codegen_blob.py diff --git a/cmake/util/FindLLVM.cmake b/cmake/util/FindLLVM.cmake index 7e759ab20037..e50c7d0b01d6 100644 --- a/cmake/util/FindLLVM.cmake +++ b/cmake/util/FindLLVM.cmake @@ -95,5 +95,8 @@ macro(find_llvm use_llvm) message(STATUS "Found LLVM_INCLUDE_DIRS=" ${LLVM_INCLUDE_DIRS}) message(STATUS "Found LLVM_DEFINITIONS=" ${LLVM_DEFINITIONS}) message(STATUS "Found TVM_LLVM_VERSION=" ${TVM_LLVM_VERSION}) + if (${TVM_LLVM_VERSION} LESS 40) + message(FATAL_ERROR "TVM requires LLVM 4.0 or higher.") + endif() endif() endmacro(find_llvm) diff --git a/include/tvm/codegen.h b/include/tvm/codegen.h index 78fb7d15b6f9..218a7827ba1e 100644 --- a/include/tvm/codegen.h +++ b/include/tvm/codegen.h @@ -59,6 +59,21 @@ runtime::Module Build(const Array& funcs, * \return cstr The C string representation of the file. */ std::string PackImportsToC(const runtime::Module& m, bool system_lib); + +/*! + * \brief Pack imported device library to a LLVM module. + * Compile the LLVM module and link with the host library + * will allow the DSO loader to automatically discover and import + * the dependency from the shared library. + * + * \param m The host module with the imports. + * \param system_lib Whether expose as system library. + * \param target_triple LLVM target triple + * \return runtime::Module The generated LLVM module. + */ +runtime::Module PackImportsToLLVM(const runtime::Module& m, + bool system_lib, + const std::string& target_triple); } // namespace codegen } // namespace tvm diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index 1550d5abf989..0c836144577b 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -51,10 +51,41 @@ def create_shared(output, else: raise ValueError("Unsupported platform") +def get_target_by_dump_machine(compiler): + """ Functor of get_target_triple that can get the target triple using compiler. + + Parameters + ---------- + compiler : Optional[str] + The compiler. + + Returns + ------- + out: Callable + A function that can get target triple according to dumpmachine option of compiler. + """ + def get_target_triple(): + """ Get target triple according to dumpmachine option of compiler.""" + if compiler: + cmd = [compiler, "-dumpmachine"] + proc = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + (out, _) = proc.communicate() + if proc.returncode != 0: + msg = "dumpmachine error:\n" + msg += py_str(out) + return None + return py_str(out) + else: + return None + + return get_target_triple + # assign so as default output format create_shared.output_format = "so" if sys.platform != "win32" else "dll" - +create_shared.get_target_triple = get_target_by_dump_machine( + "g++" if sys.platform == "darwin" or sys.platform.startswith("linux") else None) def build_create_shared_func(options=None, compile_cmd="g++"): """Build create_shared function with particular default options and compile_cmd. @@ -75,10 +106,11 @@ def build_create_shared_func(options=None, compile_cmd="g++"): def create_shared_wrapper(output, objects, options=options, compile_cmd=compile_cmd): create_shared(output, objects, options, compile_cmd) create_shared_wrapper.output_format = create_shared.output_format + create_shared_wrapper.get_target_triple = get_target_by_dump_machine(compile_cmd) return create_shared_wrapper -def cross_compiler(compile_func, base_options=None, output_format="so"): +def cross_compiler(compile_func, base_options=None, output_format="so", get_target_triple=None): """Create a cross compiler function. Parameters @@ -92,6 +124,9 @@ def cross_compiler(compile_func, base_options=None, output_format="so"): output_format : Optional[str] Library output format. + get_target_triple: Optional[Callable] + Function that can target triple according to dumpmachine option of compiler. + Returns ------- fcompile : Callable[[str, str, Optional[str]], None] @@ -105,6 +140,7 @@ def _fcompile(outputs, objects, options=None): all_options += options compile_func(outputs, objects, options=all_options) _fcompile.output_format = output_format + _fcompile.get_target_triple = get_target_triple return _fcompile diff --git a/python/tvm/contrib/ndk.py b/python/tvm/contrib/ndk.py index e1703ce03f8e..bada95ff0cdd 100644 --- a/python/tvm/contrib/ndk.py +++ b/python/tvm/contrib/ndk.py @@ -21,6 +21,7 @@ import subprocess import os from .._ffi.base import py_str +from .cc import get_target_by_dump_machine def create_shared(output, objects, @@ -64,5 +65,8 @@ def create_shared(output, msg += py_str(out) raise RuntimeError(msg) + # assign output format create_shared.output_format = "so" +create_shared.get_target_triple = get_target_by_dump_machine( + os.environ["TVM_NDK_CC"]) if "TVM_NDK_CC" in os.environ else None diff --git a/python/tvm/module.py b/python/tvm/module.py index e9e229469831..f8ad0a447941 100644 --- a/python/tvm/module.py +++ b/python/tvm/module.py @@ -123,6 +123,7 @@ def export_library(self, files = [] is_system_lib = False has_c_module = False + llvm_target_triple = None for index, module in enumerate(modules): if fcompile is not None and hasattr(fcompile, "object_format"): object_format = fcompile.object_format @@ -138,18 +139,29 @@ def export_library(self, files.append(path_obj) is_system_lib = (module.type_key == "llvm" and module.get_function("__tvm_is_system_module")()) - - if self.imported_modules: - path_cc = temp.relpath("devc.cc") - with open(path_cc, "w") as f: - f.write(_PackImportsToC(self, is_system_lib)) - files.append(path_cc) + llvm_target_triple = (module.type_key == "llvm" and + module.get_function("_get_target_triple")()) if not fcompile: if file_name.endswith(".tar"): fcompile = _tar.tar else: fcompile = _cc.create_shared + if llvm_target_triple is None and hasattr(fcompile, "get_target_triple"): + llvm_target_triple = fcompile.get_target_triple() + + if self.imported_modules: + if enabled("llvm") and llvm_target_triple: + path_obj = temp.relpath("devc.o") + m = _PackImportsToLLVM(self, is_system_lib, llvm_target_triple) + m.save(path_obj) + files.append(path_obj) + else: + path_cc = temp.relpath("devc.cc") + with open(path_cc, "w") as f: + f.write(_PackImportsToC(self, is_system_lib)) + files.append(path_cc) + if has_c_module: options = [] if "options" in kwargs: diff --git a/src/api/api_codegen.cc b/src/api/api_codegen.cc index a58e905aff13..1d997a2ae093 100644 --- a/src/api/api_codegen.cc +++ b/src/api/api_codegen.cc @@ -43,5 +43,8 @@ TVM_REGISTER_GLOBAL("codegen._Build") TVM_REGISTER_GLOBAL("module._PackImportsToC") .set_body_typed(PackImportsToC); + +TVM_REGISTER_GLOBAL("module._PackImportsToLLVM") +.set_body_typed(PackImportsToLLVM); } // namespace codegen } // namespace tvm diff --git a/src/codegen/codegen.cc b/src/codegen/codegen.cc index 60b12dc6e553..a038d4c56bb6 100644 --- a/src/codegen/codegen.cc +++ b/src/codegen/codegen.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -158,13 +159,21 @@ class ModuleSerializer { std::vector import_tree_child_indices_; }; -std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { - std::string bin; - dmlc::MemoryStringStream ms(&bin); - dmlc::Stream* stream = &ms; +namespace { + std::string SerializeModule(const runtime::Module& mod) { + std::string bin; + dmlc::MemoryStringStream ms(&bin); + dmlc::Stream* stream = &ms; + + ModuleSerializer module_serializer(mod); + module_serializer.SerializeModule(stream); + + return bin; + } +} // namespace - ModuleSerializer module_serializer(mod); - module_serializer.SerializeModule(stream); +std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { + std::string bin = SerializeModule(mod); // translate to C program std::ostringstream os; @@ -211,5 +220,29 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { << "#endif\n"; return os.str(); } + +runtime::Module PackImportsToLLVM(const runtime::Module& mod, + bool system_lib, + const std::string& target_triple) { + std::string bin = SerializeModule(mod); + + uint64_t nbytes = bin.length(); + std::string header; + for (size_t i = 0; i < sizeof(nbytes); ++i) { + header.push_back(((nbytes >> (i * 8)) & 0xffUL)); + } + std::string blob = header + bin; + TVMByteArray blob_byte_array; + blob_byte_array.size = blob.length(); + blob_byte_array.data = blob.data(); + + // Call codegen_blob to generate LLVM module + std::string codegen_f_name = "codegen.codegen_blob"; + // the codegen function. + const PackedFunc* codegen_f = runtime::Registry::Get(codegen_f_name); + CHECK(codegen_f != nullptr) << "codegen.codegen_blob is not presented."; + return (*codegen_f)(blob_byte_array, system_lib, target_triple); +} + } // namespace codegen } // namespace tvm diff --git a/src/codegen/llvm/codegen_blob.cc b/src/codegen/llvm/codegen_blob.cc new file mode 100644 index 000000000000..be8ef9262765 --- /dev/null +++ b/src/codegen/llvm/codegen_blob.cc @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file codegen_blob.cc + */ +#ifdef TVM_LLVM_VERSION +#include +#include +#include "codegen_blob.h" + +namespace tvm { +namespace codegen { + +std::pair, + std::shared_ptr> CodeGenBlob(const std::string& data, + bool system_lib, + const std::string& target_triple) { + InitializeLLVM(); + auto tm = GetLLVMTargetMachine(std::string("-target ") + target_triple); + auto triple = tm->getTargetTriple(); + auto ctx = std::make_shared(); + std::string module_name = "devc"; + std::unique_ptr module(new llvm::Module(module_name, *ctx)); + module->setTargetTriple(triple.str()); + module->setDataLayout(tm->createDataLayout()); + auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false); + auto* tvm_dev_mblob = new llvm::GlobalVariable(*module, blob_value->getType(), true, + llvm::GlobalValue::ExternalLinkage, blob_value, + runtime::symbol::tvm_dev_mblob, nullptr, + llvm::GlobalVariable::NotThreadLocal, 0); + +#if TVM_LLVM_VERSION >= 100 + tvm_dev_mblob->setAlignment(llvm::Align(1)); +#else + tvm_dev_mblob->setAlignment(1); +#endif + + if (triple.isOSWindows()) { + tvm_dev_mblob->setDLLStorageClass(llvm::GlobalVariable::DLLExportStorageClass); + } + + if (system_lib) { + // LLVM type helper + auto void_ty = llvm::Type::getVoidTy(*ctx); + auto int32_ty = llvm::Type::getInt32Ty(*ctx); + auto int8_ty = llvm::Type::getInt8Ty(*ctx); + auto int8_ptr_ty = int8_ty->getPointerTo(0); + + llvm::Constant* constant_zero = llvm::Constant::getNullValue(int32_ty); + auto* tvm_dev_mblob_reg = + new llvm::GlobalVariable(*module, int32_ty, + false, llvm::GlobalValue::InternalLinkage, + constant_zero, + std::string(runtime::symbol::tvm_dev_mblob) + "_reg_"); + auto tvm_dev_mblob_reg_alignment = module->getDataLayout().getABITypeAlignment(int32_ty); +#if TVM_LLVM_VERSION >= 100 + tvm_dev_mblob_reg->setAlignment(llvm::Align(tvm_dev_mblob_reg_alignment)); +#else + tvm_dev_mblob_reg->setAlignment(tvm_dev_mblob_reg_alignment); +#endif + + auto* tvm_dev_mblob_string_ty = + llvm::ArrayType::get(int8_ty, std::strlen(runtime::symbol::tvm_dev_mblob) + 1); + auto* tvm_dev_mblob_string_value = + llvm::ConstantDataArray::getString(*ctx, runtime::symbol::tvm_dev_mblob, true); + auto* tvm_dev_mblob_string = + new llvm::GlobalVariable(*module, tvm_dev_mblob_string_ty, + true, llvm::GlobalValue::PrivateLinkage, + tvm_dev_mblob_string_value, + std::string(runtime::symbol::tvm_dev_mblob) + ".str"); +#if TVM_LLVM_VERSION >= 100 + tvm_dev_mblob_string->setAlignment(llvm::Align(1)); +#else + tvm_dev_mblob_string->setAlignment(1); +#endif + + // Global init function + llvm::Function* init_fn = llvm::Function::Create(llvm::FunctionType::get(void_ty, false), + llvm::GlobalValue::InternalLinkage, + llvm::Twine("_GLOBAL__sub_I_", module_name), + module.get()); + + // Create variable initialization function. + llvm::Function* var_init_fn = llvm::Function::Create(llvm::FunctionType::get(void_ty, false), + llvm::GlobalValue::InternalLinkage, + llvm::Twine("__cxx_global_var_init"), + module.get()); + + // Create TVMBackendRegisterSystemLibSymbol function + llvm::Function* tvm_backend_fn = + llvm::Function::Create(llvm::FunctionType::get(int32_ty, {int8_ptr_ty, int8_ptr_ty}, false), + llvm::GlobalValue::ExternalLinkage, + llvm::Twine("TVMBackendRegisterSystemLibSymbol"), + module.get()); + + // Set necessary fn sections + auto get_static_init_section_specifier = [&triple]() -> std::string { + if (triple.isOSLinux()) { + return ".text.startup"; + } else if (triple.isOSDarwin()) { + return "__TEXT,__StaticInit,regular,pure_instructions"; + } else { + return ""; + } + }; + + auto static_init_section_specifier = get_static_init_section_specifier(); + + if (!static_init_section_specifier.empty()) { + init_fn->setSection(static_init_section_specifier); + var_init_fn->setSection(static_init_section_specifier); + } + + // The priority is 65535 for all platforms as clang do. + llvm::appendToGlobalCtors(*module, init_fn, 65535); + + // Define init_fn body + llvm::IRBuilder<> ir_builder(*ctx); + llvm::BasicBlock* init_fn_bb = llvm::BasicBlock::Create(*ctx, "entry", init_fn); + ir_builder.SetInsertPoint(init_fn_bb); + ir_builder.CreateCall(var_init_fn); + ir_builder.CreateRetVoid(); + + // Define var_init_fn body + llvm::BasicBlock* var_init_fn_bb = llvm::BasicBlock::Create(*ctx, "entry", var_init_fn); + ir_builder.SetInsertPoint(var_init_fn_bb); + llvm::Constant* indices[] = {constant_zero, constant_zero}; + llvm::SmallVector args; + args.push_back(llvm::ConstantExpr::getGetElementPtr(tvm_dev_mblob_string_ty, + tvm_dev_mblob_string, + indices)); + args.push_back(llvm::ConstantExpr::getGetElementPtr(blob_value->getType(), + tvm_dev_mblob, + indices)); + auto* tvm_backend_fn_ret_value = ir_builder.CreateCall(tvm_backend_fn, args); + ir_builder.CreateStore(tvm_backend_fn_ret_value, tvm_dev_mblob_reg); + ir_builder.CreateRetVoid(); + } + + return std::make_pair(std::move(module), ctx); +} + +} // namespace codegen +} // namespace tvm +#endif // TVM_LLVM_VERSION diff --git a/src/codegen/llvm/codegen_blob.h b/src/codegen/llvm/codegen_blob.h new file mode 100644 index 000000000000..79c0d385cfbf --- /dev/null +++ b/src/codegen/llvm/codegen_blob.h @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file codegen_blob.h + * \brief Code Generation of blob data + */ +#ifndef TVM_CODEGEN_LLVM_CODEGEN_BLOB_H_ +#define TVM_CODEGEN_LLVM_CODEGEN_BLOB_H_ +#ifdef TVM_LLVM_VERSION +#include +#include +#include +#include "llvm_common.h" + +namespace tvm { +namespace codegen { +/** + * \brief Code Generation of blob data + * + * \param data Blob data + * \param system_lib Whether expose as system library. + * \param target_triple LLVM target triple + * + * \return LLVM module and LLVM context + */ +std::pair, + std::shared_ptr> CodeGenBlob(const std::string& data, + bool system_lib, + const std::string& target_triple); + +} // namespace codegen +} // namespace tvm +#endif // LLVM_VERSION +#endif // TVM_CODEGEN_LLVM_CODEGEN_BLOB_H_ diff --git a/src/codegen/llvm/llvm_module.cc b/src/codegen/llvm/llvm_module.cc index 933e78fc4f41..32f5451ff014 100644 --- a/src/codegen/llvm/llvm_module.cc +++ b/src/codegen/llvm/llvm_module.cc @@ -29,6 +29,7 @@ #include #include "llvm_common.h" #include "codegen_llvm.h" +#include "codegen_blob.h" #include "../../runtime/file_util.h" #include "../../runtime/library_module.h" @@ -62,6 +63,11 @@ class LLVMModuleNode final : public runtime::ModuleNode { return PackedFunc([flag](TVMArgs args, TVMRetValue *rv) { * rv = flag; }); + } else if (name == "_get_target_triple") { + std::string target_triple = tm_->getTargetTriple().str(); + return PackedFunc([target_triple](TVMArgs args, TVMRetValue *rv) { + * rv = target_triple; + }); } if (ee_ == nullptr) LazyInitJIT(); std::lock_guard lock(mutex_); @@ -218,15 +224,15 @@ class LLVMModuleNode final : public runtime::ModuleNode { mptr_ = module_.get(); } - void LoadIR(const std::string& file_name) { + void Init(std::unique_ptr module, + std::shared_ptr ctx) { InitializeLLVM(); - ctx_ = std::make_shared(); + ctx_ = ctx; llvm::SMDiagnostic err; - module_ = llvm::parseIRFile(file_name, err, *ctx_); - if (module_.get() == nullptr) { + module_ = std::move(module); + if (module_ == nullptr) { std::string msg = err.getMessage(); - LOG(FATAL) << "Fail to load ir file " << file_name << "\n" - << "line " << err.getLineNo() << ":" << msg; + LOG(FATAL) << "Fail to load module: " << msg; } std::string target_; llvm::Metadata* mtarget = module_->getModuleFlag("tvm_target"); @@ -243,6 +249,18 @@ class LLVMModuleNode final : public runtime::ModuleNode { tm_ = GetLLVMTargetMachine(target_); } + void LoadIR(const std::string& file_name) { + auto ctx = std::make_shared(); + llvm::SMDiagnostic err; + auto module = llvm::parseIRFile(file_name, err, *ctx); + if (module == nullptr) { + std::string msg = err.getMessage(); + LOG(FATAL) << "Fail to load ir file " << file_name << "\n" + << "line " << err.getLineNo() << ":" << msg; + } + Init(std::move(module), ctx); + } + private: void LazyInitJIT() { std::lock_guard lock(mutex_); @@ -339,7 +357,7 @@ TVM_REGISTER_GLOBAL("codegen.llvm_lookup_intrinsic_id") TVM_REGISTER_GLOBAL("codegen.build_llvm") .set_body([](TVMArgs args, TVMRetValue* rv) { auto n = make_object(); - n->Init(args[0], args[1]); + n->Init(args[0].operator Array(), args[1].operator std::string()); *rv = runtime::Module(n); }); @@ -362,6 +380,16 @@ TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled") InitializeLLVM(); *rv = (GetLLVMTargetMachine(args[0], true) != nullptr); }); + +TVM_REGISTER_GLOBAL("codegen.codegen_blob") +.set_body([](TVMArgs args, TVMRetValue* rv) { + auto n = make_object(); + auto p = CodeGenBlob(args[0].operator std::string(), + args[1].operator bool(), + args[2].operator std::string()); + n->Init(std::move(p.first), p.second); + *rv = runtime::Module(n); +}); } // namespace codegen } // namespace tvm #endif // TVM_LLVM_VERSION diff --git a/tests/python/unittest/test_codegen_blob.py b/tests/python/unittest/test_codegen_blob.py new file mode 100644 index 000000000000..1d715ba68264 --- /dev/null +++ b/tests/python/unittest/test_codegen_blob.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +from tvm import relay +from tvm.relay import testing +from tvm.contrib import graph_runtime +import tvm +import ctypes + +def test_resnet18(): + for device in ["llvm", "cuda"]: + if not tvm.module.enabled(device): + print("skip because %s is not enabled..." % device) + return + + def verify(data): + mod, params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) + ctx = tvm.cpu() + module = graph_runtime.create(graph, lib, ctx) + module.set_input("data", data) + module.set_input(**graph_params) + module.run() + out = module.get_output(0).asnumpy() + return out + + resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18) + with relay.build_config(opt_level=3): + graph, resnet18_gpu_lib, graph_params = relay.build_module.build(resnet18_mod, "cuda", params=resnet18_params) + + from tvm.contrib import util + temp = util.tempdir() + path_lib = temp.relpath("deploy_lib.so") + resnet18_gpu_lib.export_library(path_lib) + with open(temp.relpath("deploy_graph.json"), "w") as fo: + fo.write(graph) + with open(temp.relpath("deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(graph_params)) + + loaded_lib = tvm.module.load(path_lib) + loaded_json = open(temp.relpath("deploy_graph.json")).read() + loaded_params = bytearray(open(temp.relpath("deploy_param.params"), "rb").read()) + data = np.random.uniform(-1, 1, size=(1, 3, 224, 224)).astype("float32") + ctx = tvm.gpu() + module = graph_runtime.create(loaded_json, loaded_lib, ctx) + module.load_params(loaded_params) + module.set_input("data", data) + module.run() + out = module.get_output(0).asnumpy() + + tvm.testing.assert_allclose(out, verify(data), atol=1e-5) + + +def test_system_lib(): + ctx = tvm.gpu(0) + for device in ["llvm", "cuda"]: + if not tvm.module.enabled(device): + print("skip because %s is not enabled..." % device) + return + nn = 12 + n = tvm.convert(nn) + A = tvm.placeholder((n,), name='A') + B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') + s = tvm.create_schedule(B.op) + bx, tx = s[B].split(B.op.axis[0], factor=4) + s[B].bind(bx, tvm.thread_axis("blockIdx.x")) + s[B].bind(tx, tvm.thread_axis("threadIdx.x")) + + from tvm.contrib import util + temp = util.tempdir() + fn_add = tvm.build(s, [A, B], target="cuda", target_host="llvm -system-lib", name="add") + path_obj = temp.relpath("add.o") + path_lib = temp.relpath("deploy_lib.so") + fn_add.save(path_obj) + fn_add.export_library(path_lib) + # Load dll, will trigger system library registration + dll = ctypes.CDLL(path_lib) + # Load the system wide library + m = tvm.module.system_lib() + a = tvm.nd.array(np.random.uniform(size=nn).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(nn, dtype=A.dtype), ctx) + m['add'](a, b) + np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) + + +if __name__ == "__main__": + test_resnet18() + test_system_lib()