diff --git a/apps/bundle_deploy/Makefile b/apps/bundle_deploy/Makefile index 57e484379a4e..bd4053f2911c 100644 --- a/apps/bundle_deploy/Makefile +++ b/apps/bundle_deploy/Makefile @@ -17,40 +17,80 @@ # Makefile Example to bundle TVM modules. +# Setup build environment TVM_ROOT=$(shell cd ../..; pwd) DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core -PKG_CFLAGS = -std=c++14 -O2 -fPIC\ - -I${TVM_ROOT}/include\ - -I${DMLC_CORE}/include\ +PKG_CXXFLAGS = -std=c++14 -O2 -fPIC \ + -I${TVM_ROOT}/include \ + -I${DMLC_CORE}/include \ + -I${TVM_ROOT}/3rdparty/dlpack/include +PKG_CFLAGS = -std=c99 -O2 -fPIC \ + -I${TVM_ROOT}/include \ + -I${DMLC_CORE}/include \ -I${TVM_ROOT}/3rdparty/dlpack/include PKG_LDFLAGS = -pthread build_dir := build -test: $(build_dir)/demo $(build_dir)/bundle.so - $(build_dir)/demo $(build_dir)/bundle.so +demo: $(build_dir)/demo $(build_dir)/bundle.so $(build_dir)/bundle_c.so $(build_dir)/cat.bin + TVM_NUM_THREADS=1 $(build_dir)/demo $(build_dir)/bundle.so $(build_dir)/cat.bin + TVM_NUM_THREADS=1 $(build_dir)/demo $(build_dir)/bundle_c.so $(build_dir)/cat.bin + +test: $(build_dir)/test $(build_dir)/test_bundle.so $(build_dir)/test_bundle_c.so $(build_dir)/test_data.bin $(build_dir)/test_output.bin + TVM_NUM_THREADS=1 $(build_dir)/test $(build_dir)/test_bundle.so $(build_dir)/test_data.bin $(build_dir)/test_output.bin $(build_dir)/test_graph.json $(build_dir)/test_params.bin + TVM_NUM_THREADS=1 $(build_dir)/test $(build_dir)/test_bundle_c.so $(build_dir)/test_data.bin $(build_dir)/test_output.bin $(build_dir)/test_graph.json $(build_dir)/test_params.bin + +$(build_dir)/demo: demo.cc ${build_dir}/graph.json.c ${build_dir}/params.bin.c + @mkdir -p $(@D) + g++ $(PKG_CXXFLAGS) -o $@ demo.cc -ldl -$(build_dir)/demo: demo.cc +$(build_dir)/test: test.cc ${build_dir}/test_graph.json ${build_dir}/test_params.bin @mkdir -p $(@D) - $(CXX) $(PKG_CFLAGS) -o $@ $^ -ldl + g++ $(PKG_CXXFLAGS) -o $@ test.cc -ldl # Serialize our graph.json file. 
-$(build_dir)/graph.json.cc: $(build_dir)/graph.json +$(build_dir)/graph.json.c: $(build_dir)/graph.json xxd -i $^ > $@ # Serialize our params.bin file. -$(build_dir)/params.bin.cc: $(build_dir)/params.bin +$(build_dir)/params.bin.c: $(build_dir)/params.bin xxd -i $^ > $@ -$(build_dir)/model.o $(build_dir)/graph.json $(build_dir)/params.bin: build_model.py +# # Serialize our test_graph.json file. +# $(build_dir)/test_graph.json.c: $(build_dir)/test_graph.json +# xxd -i $^ > $@ +# +# # Serialize our test_params.bin file. +# $(build_dir)/test_params.bin.c: $(build_dir)/test_params.bin +# xxd -i $^ > $@ + +$(build_dir)/model.o $(build_dir)/graph.json $(build_dir)/params.bin $(build_dir)/cat.bin: build_model.py python3 $< -o $(build_dir) -# Build our bundle against the serialized bundle.cc API, the runtime.cc API, and +$(build_dir)/test_model.o $(build_dir)/test_graph.json $(build_dir)/test_params.bin $(build_dir)/test_data.bin $(build_dir)/test_output.bin: build_model.py + python3 $< -o $(build_dir) --test + +# Build our bundle against the serialized bundle.c API, the runtime.cc API, and # the serialized graph.json and params.bin -$(build_dir)/bundle.so: bundle.cc runtime.cc $(build_dir)/model.o $(build_dir)/graph.json.cc $(build_dir)/params.bin.cc +$(build_dir)/bundle.so: bundle.cc runtime.cc $(build_dir)/model.o @mkdir -p $(@D) - $(CXX) -shared $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) + g++ -shared $(PKG_CXXFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) + +$(build_dir)/bundle_c.so: bundle.c runtime.c $(build_dir)/model.o + @mkdir -p $(@D) + gcc -shared $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) + +$(build_dir)/test_bundle.so: bundle.cc runtime.cc $(build_dir)/test_model.o + @mkdir -p $(@D) + g++ -shared $(PKG_CXXFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) + +$(build_dir)/test_bundle_c.so: bundle.c runtime.c $(build_dir)/test_model.o + @mkdir -p $(@D) + gcc -shared $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) 
clean: - rm -r $(build_dir) + rm -rf $(build_dir)/bundle.so $(build_dir)/bundle_c.so $(build_dir)/test_bundle.so $(build_dir)/test_bundle_c.so + +cleanall: + rm -rf $(build_dir) diff --git a/apps/bundle_deploy/README.md b/apps/bundle_deploy/README.md index 6be9c4f91340..676ae7d9e6c9 100644 --- a/apps/bundle_deploy/README.md +++ b/apps/bundle_deploy/README.md @@ -45,9 +45,10 @@ make demo This will: - Download the mobilenet0.25 model from the MXNet Gluon Model Zoo -- Compile the model with NNVM +- Compile the model with Relay - Build a `bundle.so` shared object containing the model specification and parameters -- Build a `demo` executable that `dlopen`'s `bundle.so`, instantiates the - contained graph runtime, and invokes the `GraphRuntime::Run` function on a - random input, then prints the output tensor to `stderr`. +- Build a `demo` executable that `dlopen`'s `bundle.so` (or `bundle_c.so` in + terms of the MISRA-C runtime), instantiates the contained graph runtime, + and invokes the `GraphRuntime::Run` function on a cat image, then prints + the output results. 
diff --git a/apps/bundle_deploy/build_model.py b/apps/bundle_deploy/build_model.py index 37e302449016..63d658e6d428 100644 --- a/apps/bundle_deploy/build_model.py +++ b/apps/bundle_deploy/build_model.py @@ -22,15 +22,9 @@ import tvm from tvm import te import logging +import json - -def main(): - logging.basicConfig(level=logging.INFO) - - parser = argparse.ArgumentParser() - parser.add_argument('-o', '--out-dir', default='.') - opts = parser.parse_args() - +def build_module(opts): dshape = (1, 3, 224, 224) from mxnet.gluon.model_zoo.vision import get_model block = get_model('mobilenet0.25', pretrained=True) @@ -53,6 +47,69 @@ def main(): with open(os.path.join(build_dir, 'params.bin'), 'wb') as f_params: f_params.write(relay.save_param_dict(params)) +def build_test_module(opts): + import numpy as np + + x = relay.var('x', shape=(10, 5)) + y = relay.var('y', shape=(1, 5)) + z = relay.add(x, y) + func = relay.Function([x, y], z) + x_data = np.random.rand(10, 5).astype('float32') + y_data = np.random.rand(1, 5).astype('float32') + params = {"y": y_data} + graph, lib, params = relay.build( + tvm.IRModule.from_expr(func), "llvm --system-lib", params=params) + + build_dir = os.path.abspath(opts.out_dir) + if not os.path.isdir(build_dir): + os.makedirs(build_dir) + + lib.save(os.path.join(build_dir, 'test_model.o')) + with open(os.path.join(build_dir, 'test_graph.json'), 'w') as f_graph_json: + f_graph_json.write(graph) + with open(os.path.join(build_dir, 'test_params.bin'), 'wb') as f_params: + f_params.write(relay.save_param_dict(params)) + with open(os.path.join(build_dir, "test_data.bin"), "wb") as fp: + fp.write(x_data.astype(np.float32).tobytes()) + x_output = x_data + y_data + with open(os.path.join(build_dir, "test_output.bin"), "wb") as fp: + fp.write(x_output.astype(np.float32).tobytes()) + +def build_inputs(opts): + from tvm.contrib import download + from PIL import Image + import numpy as np + + build_dir = os.path.abspath(opts.out_dir) + + # Download test 
image + image_url = 'https://homes.cs.washington.edu/~moreau/media/vta/cat.jpg' + image_fn = os.path.join(build_dir, "cat.png") + download.download(image_url, image_fn) + image = Image.open(image_fn).resize((224, 224)) + + def transform_image(image): + image = np.array(image) - np.array([123., 117., 104.]) + image /= np.array([58.395, 57.12, 57.375]) + image = image.transpose((2, 0, 1)) + image = image[np.newaxis, :] + return image + + x = transform_image(image) + print('x', x.shape) + with open(os.path.join(build_dir, "cat.bin"), "wb") as fp: + fp.write(x.astype(np.float32).tobytes()) if __name__ == '__main__': - main() + logging.basicConfig(level=logging.INFO) + + parser = argparse.ArgumentParser() + parser.add_argument('-o', '--out-dir', default='.') + parser.add_argument('-t', '--test', action='store_true') + opts = parser.parse_args() + + if opts.test: + build_test_module(opts) + else: + build_module(opts) + build_inputs(opts) diff --git a/apps/bundle_deploy/bundle.c b/apps/bundle_deploy/bundle.c new file mode 100644 index 000000000000..dd24bcbdc049 --- /dev/null +++ b/apps/bundle_deploy/bundle.c @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +/*! 
\brief macro to do C API call */ +#define TVM_CCALL(func) \ + do { \ + int ret = (func); \ + if (ret != 0) { \ + fprintf(stderr, "%s: %d: error: %s\n", __FILE__, __LINE__, TVMGetLastError()); \ + exit(ret); \ + } \ + } while (0) + +TVM_DLL void * tvm_runtime_create(const char * json_data, + const char * params_data, + const uint64_t params_size) { + int64_t device_type = kDLCPU; + int64_t device_id = 0; + + TVMByteArray params; + params.data = params_data; + params.size = params_size; + + TVMContext ctx; + ctx.device_type = (DLDeviceType)device_type; + ctx.device_id = device_id; + + // declare pointers + TVMModuleHandle (*SystemLibraryCreate)(); + TVMModuleHandle (*TVMGraphRuntimeCreate)(const char *, const TVMModuleHandle, const TVMContext *); + int (*TVMGraphRuntime_LoadParams)(TVMModuleHandle, const char *, const uint32_t); + + // get pointers + TVM_CCALL(TVMFuncGetGlobal("runtime.SystemLib", (TVMFunctionHandle*)&SystemLibraryCreate)); + TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.create", (TVMFunctionHandle*)&TVMGraphRuntimeCreate)); + + // run modules + TVMModuleHandle mod_syslib = SystemLibraryCreate(); + TVMModuleHandle mod = TVMGraphRuntimeCreate(json_data, mod_syslib, &ctx); + TVM_CCALL(TVMModGetFunction(mod, "load_params", 0, (TVMFunctionHandle*)&TVMGraphRuntime_LoadParams)); + TVMGraphRuntime_LoadParams(mod, params.data, params.size); + + return mod; +} + +TVM_DLL void tvm_runtime_destroy(void * runtime) { + void (*TVMGraphRuntimeRelease)(TVMModuleHandle *); + TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.release", (TVMFunctionHandle*)&TVMGraphRuntimeRelease)); + TVMGraphRuntimeRelease(&runtime); +} + +TVM_DLL void tvm_runtime_set_input(void * runtime, const char * name, DLTensor * tensor) { + void (*TVMGraphRuntime_SetInput)(TVMModuleHandle, const char *, DLTensor*); + TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.set_input", (TVMFunctionHandle*)&TVMGraphRuntime_SetInput)); + TVMGraphRuntime_SetInput(runtime, name, tensor); +} + +TVM_DLL void 
tvm_runtime_run(void * runtime) { + void (*TVMGraphRuntime_Run)(TVMModuleHandle runtime); + TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.run", (TVMFunctionHandle*)&TVMGraphRuntime_Run)); + TVMGraphRuntime_Run(runtime); +} + +TVM_DLL void tvm_runtime_get_output(void * runtime, int32_t index, DLTensor * tensor) { + int (*TVMGraphRuntime_GetOutput)(TVMModuleHandle, const int32_t, DLTensor *); + TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.get_output", (TVMFunctionHandle*)&TVMGraphRuntime_GetOutput)); + TVMGraphRuntime_GetOutput(runtime, index, tensor); +} + diff --git a/apps/bundle_deploy/bundle.cc b/apps/bundle_deploy/bundle.cc index 22f8ba300dec..3e5080927db4 100644 --- a/apps/bundle_deploy/bundle.cc +++ b/apps/bundle_deploy/bundle.cc @@ -21,16 +21,14 @@ #include #include -extern unsigned char build_graph_json[]; -extern unsigned int build_graph_json_len; -extern unsigned char build_params_bin[]; -extern unsigned int build_params_bin_len; - #define TVM_BUNDLE_FUNCTION __attribute__((visibility("default"))) extern "C" { -TVM_BUNDLE_FUNCTION void *tvm_runtime_create() { +TVM_BUNDLE_FUNCTION void *tvm_runtime_create(const char * build_graph_json, + const char * build_params_bin, + const uint64_t build_params_bin_len) { + const int build_graph_json_len = strlen(build_graph_json); const std::string json_data(&build_graph_json[0], &build_graph_json[0] + build_graph_json_len); tvm::runtime::Module mod_syslib = diff --git a/apps/bundle_deploy/demo.cc b/apps/bundle_deploy/demo.cc index 325bae780260..34be27958c91 100644 --- a/apps/bundle_deploy/demo.cc +++ b/apps/bundle_deploy/demo.cc @@ -17,13 +17,17 @@ * under the License. 
*/ -#include "tvm/runtime/c_runtime_api.h" +#include + #include #include //dlopen -#include #include #include #include +#include + +#include "build/graph.json.c" +#include "build/params.bin.c" template auto getFunc(void *bundle, const char *name) { dlerror(); @@ -34,39 +38,50 @@ template auto getFunc(void *bundle, const char *name) { } int main(int argc, char **argv) { - assert(argc == 2 && "Usage: demo "); + assert(argc == 3 && "Usage: demo "); auto *bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL); assert(bundle); - auto *handle = getFunc(bundle, "tvm_runtime_create")(); + char * json_data = reinterpret_cast(build_graph_json); + char * params_data = reinterpret_cast(build_params_bin); + uint64_t params_size = build_params_bin_len; - std::vector input_storage(1 * 3 * 224 * 224); - std::mt19937 gen(0); - for (auto &e : input_storage) { - e = std::uniform_real_distribution(0.0, 1.0)(gen); - } + struct timeval t0, t1, t2, t3, t4, t5; + gettimeofday(&t0, 0); + + auto *handle = getFunc(bundle, "tvm_runtime_create")( + json_data, params_data, params_size); + gettimeofday(&t1, 0); + + float input_storage[1 * 3 * 224 * 224]; + FILE * fp = fopen(argv[2], "rb"); + fread(input_storage, 3 * 224 * 224, 4, fp); + fclose(fp); std::vector input_shape = {1, 3, 224, 224}; DLTensor input; - input.data = input_storage.data(); + input.data = input_storage; input.ctx = DLContext{kDLCPU, 0}; input.ndim = 4; input.dtype = DLDataType{kDLFloat, 32, 1}; input.shape = input_shape.data(); input.strides = nullptr; input.byte_offset = 0; + getFunc(bundle, "tvm_runtime_set_input")( handle, "data", &input); + gettimeofday(&t2, 0); auto *ftvm_runtime_run = (auto (*)(void *)->void)dlsym(bundle, "tvm_runtime_run"); assert(!dlerror()); ftvm_runtime_run(handle); + gettimeofday(&t3, 0); - std::vector output_storage(1000); + float output_storage[1000]; std::vector output_shape = {1, 1000}; DLTensor output; - output.data = output_storage.data(); + output.data = output_storage; output.ctx = 
DLContext{kDLCPU, 0}; output.ndim = 2; output.dtype = DLDataType{kDLFloat, 32, 1}; @@ -76,10 +91,30 @@ int main(int argc, char **argv) { getFunc(bundle, "tvm_runtime_get_output")( handle, 0, &output); - for (auto i = 0; i < output_storage.size(); ++i) { - std::cerr << "output[" << i << "]: " << output_storage[i] << std::endl; + gettimeofday(&t4, 0); + + float max_iter = -std::numeric_limits::max(); + int32_t max_index = -1; + for (auto i = 0; i < 1000; ++i) { + if (output_storage[i] > max_iter) { + max_iter = output_storage[i]; + max_index = i; + } } + getFunc(bundle, "tvm_runtime_destroy")(handle); + gettimeofday(&t5, 0); + + printf("The maximum position in output vector is: %d, with max-value %f.\n", + max_index, max_iter); + printf("timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " + "%.2f ms (get_output), %.2f ms (destroy)\n", + (t1.tv_sec-t0.tv_sec)*1000000 + (t1.tv_usec-t0.tv_usec)/1000.f, + (t2.tv_sec-t1.tv_sec)*1000000 + (t2.tv_usec-t1.tv_usec)/1000.f, + (t3.tv_sec-t2.tv_sec)*1000000 + (t3.tv_usec-t2.tv_usec)/1000.f, + (t4.tv_sec-t3.tv_sec)*1000000 + (t4.tv_usec-t3.tv_usec)/1000.f, + (t5.tv_sec-t4.tv_sec)*1000000 + (t5.tv_usec-t4.tv_usec)/1000.f); dlclose(bundle); + return 0; } diff --git a/apps/bundle_deploy/runtime.c b/apps/bundle_deploy/runtime.c new file mode 100644 index 000000000000..6a53aa15f573 --- /dev/null +++ b/apps/bundle_deploy/runtime.c @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* Explicitly declare posix_memalign function */ +#if _POSIX_C_SOURCE < 200112L +#undef _POSIX_C_SOURCE +#define _POSIX_C_SOURCE 200809L +#endif + +/*! Support low-level debugging in MISRA-C runtime */ +#define TVM_CRT_DEBUG 0 + +/*! Maximum supported dimension in NDArray */ +#define TVM_CRT_MAX_NDIM 6 +/*! Maximum supported arguments in generated functions */ +#define TVM_CRT_MAX_ARGS 10 + +/*! Maximum inputs in a GraphRuntimeNode */ +#define GRAPH_RUNTIME_NODE_MAX_INPUTS 300 +/*! Maximum supported contexts in a GraphRuntime */ +#define GRAPH_RUNTIME_MAX_CONTEXTS 1 +/*! Maximum supported nodes in a GraphRuntime */ +#define GRAPH_RUNTIME_MAX_NODES 400 +/*! Maximum input nodes in a GraphRuntime */ +#define GRAPH_RUNTIME_MAX_INPUT_NODES 300 +/*! Maximum nodes in a GraphRuntime for quick entry indexing */ +#define GRAPH_RUNTIME_MAX_NODE_ROW_PTR 300 +/*! Maximum output entries in a GraphRuntime */ +#define GRAPH_RUNTIME_MAX_OUTPUTS 300 + +#include "../../src/runtime/crt/crt_runtime_api.c" +#include "../../src/runtime/crt/crt_backend_api.c" +#include "../../src/runtime/crt/graph_runtime.c" +#include "../../src/runtime/crt/load_json.c" +#include "../../src/runtime/crt/ndarray.c" + diff --git a/apps/bundle_deploy/test.cc b/apps/bundle_deploy/test.cc new file mode 100644 index 000000000000..643f1adff320 --- /dev/null +++ b/apps/bundle_deploy/test.cc @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +#include +#include //dlopen +#include +#include +#include +#include +#include + +template auto getFunc(void *bundle, const char *name) { + dlerror(); + auto *f = + reinterpret_cast::type>(dlsym(bundle, name)); + assert(!dlerror()); + return f; +} + +int main(int argc, char **argv) { + assert(argc == 6 && "Usage: test "); + auto *bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL); + assert(bundle); + + struct stat st; + char * json_data; + char * params_data; + uint64_t params_size; + + FILE * fp = fopen(argv[4], "rb"); + stat(argv[4], &st); + json_data = (char*)malloc(st.st_size); + fread(json_data, st.st_size, 1, fp); + fclose(fp); + + fp = fopen(argv[5], "rb"); + stat(argv[5], &st); + params_data = (char*)malloc(st.st_size); + fread(params_data, st.st_size, 1, fp); + params_size = st.st_size; + fclose(fp); + + struct timeval t0, t1, t2, t3, t4, t5; + gettimeofday(&t0, 0); + + auto *handle = getFunc(bundle, "tvm_runtime_create")( + json_data, params_data, params_size); + gettimeofday(&t1, 0); + + float input_storage[10 * 5]; + fp = fopen(argv[2], "rb"); + fread(input_storage, 10 * 5, 4, fp); + fclose(fp); + + float result_storage[10 * 5]; + fp = fopen(argv[3], "rb"); + fread(result_storage, 10 * 5, 4, fp); + fclose(fp); + + std::vector input_shape = {10, 
5}; + DLTensor input; + input.data = input_storage; + input.ctx = DLContext{kDLCPU, 0}; + input.ndim = 2; + input.dtype = DLDataType{kDLFloat, 32, 1}; + input.shape = input_shape.data(); + input.strides = nullptr; + input.byte_offset = 0; + + getFunc(bundle, "tvm_runtime_set_input")( + handle, "x", &input); + gettimeofday(&t2, 0); + + auto *ftvm_runtime_run = + (auto (*)(void *)->void)dlsym(bundle, "tvm_runtime_run"); + assert(!dlerror()); + ftvm_runtime_run(handle); + gettimeofday(&t3, 0); + + float output_storage[10 * 5]; + std::vector output_shape = {10, 5}; + DLTensor output; + output.data = output_storage; + output.ctx = DLContext{kDLCPU, 0}; + output.ndim = 2; + output.dtype = DLDataType{kDLFloat, 32, 1}; + output.shape = output_shape.data(); + output.strides = nullptr; + output.byte_offset = 0; + + getFunc(bundle, "tvm_runtime_get_output")( + handle, 0, &output); + gettimeofday(&t4, 0); + + for (auto i = 0; i < 10 * 5; ++i) { + assert(fabs(output_storage[i] - result_storage[i]) < 1e-5f); + if (fabs(output_storage[i] - result_storage[i]) >= 1e-5f) { + printf("got %f, expected %f\n", output_storage[i], result_storage[i]); + } + } + + getFunc(bundle, "tvm_runtime_destroy")(handle); + gettimeofday(&t5, 0); + + printf("timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " + "%.2f ms (get_output), %.2f ms (destroy)\n", + (t1.tv_sec-t0.tv_sec)*1000000 + (t1.tv_usec-t0.tv_usec)/1000.f, + (t2.tv_sec-t1.tv_sec)*1000000 + (t2.tv_usec-t1.tv_usec)/1000.f, + (t3.tv_sec-t2.tv_sec)*1000000 + (t3.tv_usec-t2.tv_usec)/1000.f, + (t4.tv_sec-t3.tv_sec)*1000000 + (t4.tv_usec-t3.tv_usec)/1000.f, + (t5.tv_sec-t4.tv_sec)*1000000 + (t5.tv_usec-t4.tv_usec)/1000.f); + + free(json_data); + free(params_data); + dlclose(bundle); + + return 0; +} diff --git a/src/runtime/crt/crt_backend_api.c b/src/runtime/crt/crt_backend_api.c new file mode 100644 index 000000000000..e011e47b2576 --- /dev/null +++ b/src/runtime/crt/crt_backend_api.c @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +#include +#include +#include +#include + +void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t nbytes, int dtype_code_hint, + int dtype_bits_hint) { + void* ptr = 0; + assert(nbytes > 0); + unsigned int dtype_bytes = dtype_bits_hint / 8; +#ifdef __ANDROID__ + ptr = memalign(64, nbytes * dtype_bytes); +#else + const int ret = posix_memalign(&ptr, 64, nbytes * dtype_bytes); + (void)ret; + assert(ret == 0); +#endif + return ptr; +} + +int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { + free(ptr); + return 0; +} + +int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) { + TVMParallelGroupEnv env; + env.num_task = 1; + flambda(0, &env, cdata); + return 0; +} + +int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) { + snprintf(g_fexecs[g_fexecs_count].name, sizeof(g_fexecs[g_fexecs_count].name), name); + g_fexecs[g_fexecs_count].fexec = ptr; + g_fexecs_count++; + return 0; +} diff --git a/src/runtime/crt/crt_runtime_api.c b/src/runtime/crt/crt_runtime_api.c new file mode 100644 index 000000000000..433ae8ad3457 --- /dev/null +++ b/src/runtime/crt/crt_runtime_api.c @@ -0,0 +1,108 @@ +/* + * Licensed 
to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +#include +#include +#include +#include + +#include "ndarray.h" +#include "graph_runtime.h" +#include "packed_func.h" + +// Handle internal errors + +static char g_last_error[1024]; + +void TVMAPISetLastError(const char* msg) { + assert(strlen(msg) < sizeof(g_last_error)); + snprintf(g_last_error, sizeof(g_last_error), "%s", msg); +} + +const char* TVMGetLastError(void) { return g_last_error; } + +// Manipulate NDArray on target device + +int TVMArrayAlloc(const tvm_index_t* shape, + int ndim, + int dtype_code, + int dtype_bits, + int dtype_lanes, + int device_type, + int device_id, + TVMArrayHandle* out) { + DLDataType dtype; + dtype.code = dtype_code; + dtype.bits = dtype_bits; + dtype.lanes = dtype_lanes; + DLContext ctx; + ctx.device_type = (DLDeviceType)device_type; + ctx.device_id = device_id; + TVMNDArray arr = TVMNDArray_Empty(ndim, shape, dtype, ctx); + **out = arr.dl_tensor; + return 0; +} + +int TVMArrayFree(TVMArrayHandle handle) { + TVMNDArray arr; + arr.dl_tensor = *handle; + return TVMNDArray_Release(&arr); +} + +void * SystemLibraryCreate() { + return 0; +} + +int TVMModGetFunction(TVMModuleHandle mod, + const char* func_name, + int query_imports, 
+ TVMFunctionHandle *out) { + int status = 0; + if (!strcmp(func_name, "load_params")) { + *out = &TVMGraphRuntime_LoadParams; + } else { + status -1; + } + return status; +} + +int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { + int status = 0; + if (!strcmp(name, "tvm.graph_runtime.create")) { + *out = &TVMGraphRuntimeCreate; + } else if (!strcmp(name, "tvm.graph_runtime.set_input")) { + *out = &TVMGraphRuntime_SetInput; + } else if (!strcmp(name, "tvm.graph_runtime.run")) { + *out = &TVMGraphRuntime_Run; + } else if (!strcmp(name, "tvm.graph_runtime.get_output")) { + *out = &TVMGraphRuntime_GetOutput; + } else if (!strcmp(name, "tvm.graph_runtime.release")) { + *out = &TVMGraphRuntimeRelease; + } else if (!strcmp(name, "runtime.SystemLib")) { + *out = &SystemLibraryCreate; + } else { + char msg[200]; + snprintf(msg, sizeof(msg), "fail to get global: name=%s", name); + TVMAPISetLastError(msg); + status = -1; + } + return status; +} diff --git a/src/runtime/crt/graph_runtime.c b/src/runtime/crt/graph_runtime.c new file mode 100644 index 000000000000..1957d0bead4b --- /dev/null +++ b/src/runtime/crt/graph_runtime.c @@ -0,0 +1,682 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file graph_runtime.c + * \brief implement graph runtime in pure C + */ +#include "graph_runtime.h" + +#ifndef MAX +#define MAX(a, b) (((a) > (b)) ? (a) : (b)) +#endif // MAX + +uint32_t Shape_Accumulate(int64_t * shape, uint32_t ndim) { + int64_t accum = 1; + uint32_t idx; + for (idx = 0; idx < ndim; idx++) { + if (shape[idx] == 0) { break; } + accum *= shape[idx]; + } + return accum; +} + +int NodeEntry_Load(TVMGraphRuntimeNodeEntry * entry, JSONReader * reader) { + int status = 0; + reader->BeginArray(reader); + if (!(reader->NextArrayItem(reader))) { + fprintf(stderr, "invalid json format: failed to parse `node_id`\n"); + } + reader->ReadUnsignedInteger(reader, &(entry->node_id)); + if (!(reader->NextArrayItem(reader))) { + fprintf(stderr, "invalid json format: failed to parse `index`\n"); + } + reader->ReadUnsignedInteger(reader, &(entry->index)); + if (reader->NextArrayItem(reader)) { + reader->ReadUnsignedInteger(reader, &(entry->version)); + if (reader->NextArrayItem(reader)) { + fprintf(stderr, "invalid json format: failed to parse `version`\n"); + } + } else { + entry->version = 0; + } + return status; +} + +void TVMGraphRuntimeNode_LoadAttrs(TVMGraphRuntimeNode * node, JSONReader *reader, + TVMOpParam* param) { + int bitmask = 0; + char key[20], value[120]; + memset(param, 0, sizeof(TVMOpParam)); + memset(key, 0, sizeof(key)); + memset(value, 0, sizeof(value)); + reader->BeginObject(reader); + while (reader->NextObjectItem(reader, key)) { + reader->ReadString(reader, value); + if (!strcmp(key, "func_name")) { + snprintf(param->func_name, sizeof(value), "%s", value); + bitmask |= 1; + } else if (!strcmp(key, "num_inputs")) { + param->num_inputs = strtoul(value, 0, 10); + bitmask |= 2; + } else if (!strcmp(key, "num_outputs")) { + param->num_outputs = strtoul(value, 0, 10); + bitmask |= 4; + } else if (!strcmp(key, "flatten_data")) { + param->flatten_data = strtoul(value, 0, 10); + bitmask |= 8; + } else { + fprintf(stderr, "do not support key %s", 
key); + } + } + if (bitmask != (1|2|4|8)) { fprintf(stderr, "invalid format\n"); } +} + +int TVMGraphRuntimeNode_Load(TVMGraphRuntimeNode * node, JSONReader *reader) { + int status = 0; + reader->BeginObject(reader); + int bitmask = 0; + char key[20]; + while (reader->NextObjectItem(reader, key)) { + if (!strcmp(key, "op")) { + reader->ReadString(reader, node->op_type); + bitmask |= 1; + } else if (!strcmp(key, "name")) { + reader->ReadString(reader, node->name); + bitmask |= 2; + } else if (!strcmp(key, "inputs")) { + size_t count = node->inputs_count; + if (count >= GRAPH_RUNTIME_NODE_MAX_INPUTS) { + fprintf(stderr, "The number of inputs in graph runtime node is greater than expected.\n"); + status = -1; + break; + } + reader->BeginArray(reader); + while (reader->NextArrayItem(reader)) { + TVMGraphRuntimeNodeEntry * inputs = node->inputs + count; + reader->BeginArray(reader); + if (!reader->NextArrayItem(reader)) { + fprintf(stderr, "invalid json format\n"); + status = -1; + break; + } + reader->ReadUnsignedInteger(reader, &(inputs->node_id)); + if (!reader->NextArrayItem(reader)) { + fprintf(stderr, "invalid json format\n"); + status = -1; + break; + } + reader->ReadUnsignedInteger(reader, &(inputs->index)); + if (reader->NextArrayItem(reader)) { + reader->ReadUnsignedInteger(reader, &(inputs->version)); + if (reader->NextArrayItem(reader)) { + fprintf(stderr, "invalid json format\n"); + status = -1; + break; + } + } else { + inputs->version = 0; + } + count++; + } + node->inputs_count = count; + bitmask |= 4; + } else if (!strcmp(key, "attr") || !strcmp(key, "attrs")) { + TVMOpParam param; + + TVMGraphRuntimeNode_LoadAttrs(node, reader, ¶m); + memcpy(&node->param, ¶m, sizeof(param)); + } else if (!strcmp(key, "control_deps")) { + fprintf(stderr, "do not support key %s", key); + status = -1; + } else { + fprintf(stderr, "do not support key %s", key); + status = -1; + } + if (status != 0) { break; } + } + if (bitmask != (1|2|4)) { fprintf(stderr, "invalid 
format\n"); } + return status; +} + +TVMGraphRuntimeNode TVMGraphRuntimeNodeCreate() { + TVMGraphRuntimeNode node; + memset(&node, 0, sizeof(TVMGraphRuntimeNode)); + node.LoadAttrs = TVMGraphRuntimeNode_LoadAttrs; + node.Load = TVMGraphRuntimeNode_Load; + return node; +} + +int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr * attr, JSONReader *reader) { + int status = 0; + int bitmask = 0; + char key[16], type[16]; + uint32_t storage_id_count = 0; + uint32_t dltype_count = 0; + uint32_t shape_count = 0; + uint32_t device_index_count = 0; + reader->BeginObject(reader); + while (reader->NextObjectItem(reader, key)) { + if (!strcmp(key, "dltype")) { + reader->BeginArray(reader); + if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + reader->ReadString(reader, type); + if (strcmp(type, "list_str")) { fprintf(stderr, "Invalid json format\n"); } + if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + reader->BeginArray(reader); + while (reader->NextArrayItem(reader)) { + reader->ReadString(reader, attr->dltype[dltype_count]); + dltype_count++; + } + attr->dltype_count = dltype_count;; + if (reader->NextArrayItem(reader)) { fprintf(stderr, "Invalid json format\n"); } + bitmask |= 1; + } else if (!strcmp(key, "storage_id")) { + reader->BeginArray(reader); + if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + reader->ReadString(reader, type); + if (strcmp(type, "list_int")) { fprintf(stderr, "Invalid json format\n"); } + if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + reader->BeginArray(reader); + while (reader->NextArrayItem(reader)) { + reader->ReadUnsignedInteger(reader, &(attr->storage_id[storage_id_count])); + storage_id_count++; + } + if (reader->NextArrayItem(reader)) { fprintf(stderr, "Invalid json format\n"); } + bitmask |= 2; + } else if (!strcmp(key, "shape")) { + reader->BeginArray(reader); + if 
(!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + reader->ReadString(reader, type); + if (strcmp(type, "list_shape")) { fprintf(stderr, "Invalid json format\n"); } + if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + reader->BeginArray(reader); + while (reader->NextArrayItem(reader)) { + reader->BeginArray(reader); + reader->ReadInteger(reader, &(attr->shape[shape_count][0])); + uint32_t ndim = 1; + if (reader->NextArrayItem(reader)) { + if (reader->NextArrayItem(reader)) { + reader->ReadInteger(reader, &(attr->shape[shape_count][1])); ndim++; + if (reader->NextArrayItem(reader)) { + reader->ReadInteger(reader, &(attr->shape[shape_count][2])); ndim++; + if (reader->NextArrayItem(reader)) { + reader->ReadInteger(reader, &(attr->shape[shape_count][3])); ndim++; + if (reader->NextArrayItem(reader)) { + reader->ReadInteger(reader, &(attr->shape[shape_count][4])); ndim++; + if (reader->NextArrayItem(reader)) { + reader->ReadInteger(reader, &(attr->shape[shape_count][5])); ndim++; + reader->NextArrayItem(reader); + } + } + } + } + } + } + attr->ndim[shape_count] = ndim; + shape_count++; + } + attr->shape_count = shape_count; + if (reader->NextArrayItem(reader)) { fprintf(stderr, "Invalid json format\n"); } + bitmask |= 4; + } else if (!strcmp(key, "device_index")) { + reader->BeginArray(reader); + if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + reader->ReadString(reader, type); + if (strcmp(type, "list_int")) { fprintf(stderr, "Invalid json format\n"); } + if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + while (reader->NextArrayItem(reader)) { + reader->ReadUnsignedInteger(reader, &(attr->device_index[device_index_count])); + device_index_count++; + } + if (reader->NextArrayItem(reader)) { fprintf(stderr, "Invalid json format\n"); } + } else { + reader->BeginArray(reader); + if (!(reader->NextArrayItem(reader))) { fprintf(stderr, 
"Invalid json format\n"); } + reader->ReadString(reader, type); + if (!strcmp(type, "list_int")) { + if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + uint32_t temp[GRAPH_RUNTIME_MAX_NODES]; + uint32_t temp_count = 0; + reader->BeginArray(reader); + while (reader->NextArrayItem(reader)) { + reader->ReadUnsignedInteger(reader, &(temp[temp_count])); + temp_count++; + } + } else if (!strcmp(type, "size_t")) { + if (!(reader->NextArrayItem(reader))) { fprintf(stderr, "Invalid json format\n"); } + uint32_t temp; + reader->ReadUnsignedInteger(reader, &temp); + } else { + fprintf(stderr, "cannot skip graph attr %s", key); + } + if (reader->NextArrayItem(reader)) { fprintf(stderr, "Invalid json format\n"); } + } + } + if (bitmask != (1|2|4)) { fprintf(stderr, "invalid format\n"); } + return status; +} + +int TVMGraphRuntime_Load(TVMGraphRuntime * runtime, JSONReader *reader) { + int status = 0; + reader->BeginObject(reader); + int bitmask = 0; + char key[20]; + while (reader->NextObjectItem(reader, key)) { + if (!strcmp(key, "nodes")) { + reader->BeginArray(reader); + while (reader->NextArrayItem(reader)) { + TVMGraphRuntimeNode * node = runtime->nodes + runtime->nodes_count; + status = TVMGraphRuntimeNode_Load(node, reader); + if (status != 0) { + fprintf(stderr, "failed to load an element in `nodes` field in graph runtime node.\n"); + break; +#if TVM_CRT_DEBUG + } else { + printf("layer %u: `%s` loaded.\n", runtime->nodes_count, node->name); +#endif // TVM_CRT_DEBUG + } + runtime->nodes_count++; + } + bitmask |= 1; + } else if (!strcmp(key, "arg_nodes")) { + reader->BeginArray(reader); + while (reader->NextArrayItem(reader)) { + uint32_t * node = runtime->input_nodes + runtime->input_nodes_count; + reader->ReadUnsignedInteger(reader, node); + runtime->input_nodes_count++; + } + bitmask |= 2; + } else if (!strcmp(key, "node_row_ptr")) { + reader->BeginArray(reader); + while (reader->NextArrayItem(reader)) { + uint32_t count = 
runtime->node_row_ptr_count; + uint32_t * node = runtime->node_row_ptr + count; + reader->ReadUnsignedInteger(reader, node); + runtime->node_row_ptr_count++; + } + bitmask |= 4; + } else if (!strcmp(key, "heads")) { + reader->BeginArray(reader); + while (reader->NextArrayItem(reader)) { + TVMGraphRuntimeNodeEntry * entry = runtime->outputs + runtime->outputs_count; + status = NodeEntry_Load(entry, reader); + if (status != 0) { + fprintf(stderr, "Fail to load an element in `heads` field in graph runtime node.\n"); + break; + } + runtime->outputs_count++; + } + bitmask |= 8; + } else if (!strcmp(key, "attrs")) { + status = TVMGraphRuntimeGraphAttr_Load(&(runtime->attrs), reader); + if (status != 0) { + fprintf(stderr, "Fail to load an element in `heads` field in graph runtime node.\n"); + break; + } + bitmask |= 16; + } else if (!strcmp(key, "metadata")) { + break; + } else { + fprintf(stderr, "key %s is not supported\n", key); + status = -1; + } + if (status != 0) { break; } + } + if (!(bitmask == (1|2|4|8|16))) { fprintf(stderr, "invalid format\n"); } + return status; +} + +uint32_t TVMGraphRuntime_GetEntryId(TVMGraphRuntime * runtime, + uint32_t nid, uint32_t index) { + return runtime->node_row_ptr[nid] + index; +} + +/*! + * \brief Get the input index given the name of input. + * \param name The name of the input. + * \return The index of input. + */ +int TVMGraphRuntime_GetInputIndex(TVMGraphRuntime * runtime, const char * name) { + uint32_t i; + int32_t rv = -1; + for (i = 0; i< runtime->input_nodes_count; ++i) { + uint32_t nid = runtime->input_nodes[i]; + if (!strcmp(runtime->nodes[nid].name, name)) { + rv = i; + break; + } + } + if (rv < 0) { + fprintf(stderr, "cannot find \"%s\" among input\n", name); + } + return rv; +} + +/*! + * \brief set index-th input to the graph. + * \param index The input index. + * \param data_in The input data. 
+ */ +void TVMGraphRuntime_SetInput(TVMGraphRuntime * runtime, const char * name, DLTensor* data_in) { + uint32_t index = runtime->GetInputIndex(runtime, name); + if (index >= runtime->input_nodes_count) { + fprintf(stderr, "given index is greater than num of input nodes.\n"); + } + uint32_t eid = runtime->GetEntryId(runtime, runtime->input_nodes[index], 0); + runtime->data_entry[eid].dl_tensor = *data_in; +} + +int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blob, + const uint32_t param_size) { + int status = 0; + const char * bptr = param_blob; + uint64_t header, reserved; + header = ((uint64_t*)bptr)[0]; // NOLINT(*) + bptr += sizeof(header); + if (header != kTVMNDArrayListMagic) { + fprintf(stderr, "Invalid parameters file format"); + } + reserved = ((uint64_t*)bptr)[0]; // NOLINT(*) + bptr += sizeof(reserved); + + // read names + char names[GRAPH_RUNTIME_MAX_NODES][80]; + memset(names, 0, sizeof(names)); + uint64_t names_count; + int idx; + names_count = ((uint64_t*)bptr)[0]; // NOLINT(*) + bptr += sizeof(names_count); + for (idx = 0; idx < names_count; idx++) { + uint64_t name_length; + name_length = ((uint64_t*)bptr)[0]; // NOLINT(*) + bptr += sizeof(name_length); + if (name_length >= 80) { + fprintf(stderr, "Error: function name longer than expected.\n"); + } + memcpy(names[idx], bptr, name_length); + bptr += name_length; + } + + // read sizes + uint64_t sz; + sz = ((uint64_t*)bptr)[0]; // NOLINT(*) + bptr += sizeof(sz); + uint32_t size = sz; + if (size != names_count) { + fprintf(stderr, "Invalid parameters file format\n"); + status = -1; + } + + for (idx = 0; idx < size; idx++) { + int32_t in_idx = runtime->GetInputIndex(runtime, names[idx]); + if (!(in_idx >= 0)) { + fprintf(stderr, "Found param for non-existent input: %s\n", names[idx]); + status = -1; + } + uint32_t eid = runtime->GetEntryId(runtime, runtime->input_nodes[in_idx], 0); + if (!(eid < runtime->data_entry_count)) { + fprintf(stderr, "`entry_id`=%d is greater 
than expected(%d).\n", + eid, runtime->data_entry_count); + status = -1; + } + + status |= TVMNDArray_Load(&(runtime->data_entry[eid]), &bptr); +#if TVM_CRT_DEBUG + TVMNDArray * entry = &(runtime->data_entry[eid]); + printf("param %s loaded, in_idx=%d, eid=%d, ndim=%d, data[0]=%f\n", + names[idx], in_idx, eid, entry->dl_tensor.ndim, + ((float*)entry->dl_tensor.data)[0]); // NOLINT(*) +#endif // TVM_CRT_DEBUG + } + + return status; +} + +/*! + * \brief Run all the operations one by one. + */ +void TVMGraphRuntime_Run(TVMGraphRuntime * runtime) { + // setup the array and requirements. + uint32_t idx; + for (idx = 0; idx < runtime->op_execs_count; ++idx) { + if (runtime->op_execs[idx].fexec) { +#if TVM_CRT_DEBUG + printf("calling %s (%d)\n", runtime->op_execs[idx].name, idx); +#endif // TVM_CRT_DEBUG + runtime->op_execs[idx].Call(&(runtime->op_execs[idx])); + } + } +} + +int TVMGraphRuntime_GetOutput(TVMGraphRuntime * runtime, const int32_t idx, DLTensor * out) { + int status = 0; + uint32_t nid = runtime->outputs[idx].node_id; + uint32_t index = runtime->outputs[idx].index; + uint32_t eid = runtime->GetEntryId(runtime, nid, index); + + // copy data section to allocated output tensor + int32_t elem_bytes = out->dtype.bits / 8; + int64_t size = Shape_Accumulate(out->shape, out->ndim); + DLTensor * tensor = &(runtime->data_entry[eid].dl_tensor); + assert(out->ndim == tensor->ndim); + assert(out->dtype.bits == tensor->dtype.bits); + assert(Shape_Accumulate(out->shape, out->ndim) == Shape_Accumulate(tensor->shape, tensor->ndim)); + memcpy(out->data, tensor->data, size * elem_bytes); + return status; +} + +void TVMGraphRuntime_SetupStorage(TVMGraphRuntime * runtime) { + uint32_t idx, dim; + + // Grab saved optimization plan from graph. 
+ DLDataType vtype[GRAPH_RUNTIME_MAX_NODES]; + TVMGraphRuntimeGraphAttr * attrs = &(runtime->attrs); + for (idx = 0; idx < attrs->dltype_count; idx++) { + vtype[idx] = String2DLDataType(attrs->dltype[idx]); + } + + // Size and device type of each storage pool entry. + TVMGraphRuntimePoolEntry pool_entry[GRAPH_RUNTIME_MAX_NODES]; + memset(pool_entry, 0, sizeof(pool_entry)); + uint32_t pool_entry_count = 0; + // Find the maximum space size. + for (idx = 0; idx < attrs->shape_count; idx++) { + int storage_id = attrs->storage_id[idx]; + // Use the fallback device if no device index is available. + int device_type = runtime->ctxs[0].device_type; + uint32_t size = Shape_Accumulate(attrs->shape[idx], attrs->ndim[idx]); + DLDataType t = vtype[idx]; + uint32_t bits = t.bits * t.lanes; + size_t bytes = ((bits + 7U) / 8U) * size; + + uint32_t sid = storage_id; + if (sid >= pool_entry_count) { + pool_entry_count = sid + 1; + } + pool_entry[sid].size = MAX(pool_entry[sid].size, bytes); + pool_entry[sid].device_type = device_type; + } + + // Allocate the space. + for (idx = 0; idx < pool_entry_count; idx++) { + TVMGraphRuntimePoolEntry pit = pool_entry[idx]; + int64_t shape[TVM_CRT_MAX_NDIM] = {0, }; + TVMContext ctx = runtime->ctxs[0]; + DLDataType dtype = {kDLFloat, 32, 1}; + shape[0] = (pit.size + 3) / 4; + runtime->storage_pool[runtime->storage_pool_count] = TVMNDArray_Empty(1, shape, dtype, ctx); + if (runtime->storage_pool[runtime->storage_pool_count].dl_tensor.data == 0) { + fprintf(stderr, "fail to create storage_pool with idx=%d\n", idx); + } + runtime->storage_pool_count++; + } + + // Assign the pooled entries. A unified memory pool is used to simplifiy + // memory assignment for each node entry. The allocated memory on each device + // is mapped to this pool. 
+ runtime->data_entry_count = runtime->node_row_ptr[runtime->node_row_ptr_count - 1]; + for (idx = 0; idx < runtime->data_entry_count; ++idx) { + size_t storage_id = attrs->storage_id[idx]; + assert(storage_id < runtime->storage_pool_count); + runtime->data_entry[idx] = + TVMNDArray_CreateView(&(runtime->storage_pool[storage_id]), + attrs->shape[idx], attrs->ndim[idx], vtype[idx]); + if (runtime->data_entry[idx].dl_tensor.data == 0) { + fprintf(stderr, "fail to create for node with idx=%d, storage_id=%d\n", idx, storage_id); + } + } +} + +int TVMGraphRuntime_SetupOpExecs(TVMGraphRuntime * runtime) { + int status = 0; + uint32_t nid, idx; + runtime->op_execs_count = runtime->nodes_count; + for (nid = 0; nid < runtime->nodes_count; nid++) { + const TVMGraphRuntimeNode * inode = runtime->nodes + nid; + if (strcmp(inode->op_type, "null")) { + DLTensorPtr args[GRAPH_RUNTIME_MAX_NODES]; + uint32_t args_count = 0; + for (idx = 0; idx < inode->inputs_count; idx++) { + const TVMGraphRuntimeNodeEntry * entry = inode->inputs + idx; + uint32_t eid = runtime->GetEntryId(runtime, entry->node_id, entry->index); + args[idx] = &(runtime->data_entry[eid].dl_tensor); + args_count++; + } + for (idx = 0; idx < inode->param.num_outputs; idx++) { + uint32_t eid = runtime->GetEntryId(runtime, nid, idx); + args[args_count] = &(runtime->data_entry[eid].dl_tensor); + args_count++; + } + if (strcmp(inode->op_type, "tvm_op")) { + fprintf(stderr, "Can only take tvm_op as op\n"); status = -1; + break; + } + if (args_count >= TVM_CRT_MAX_ARGS) { + fprintf(stderr, "too many arguments: expected less than %d args, but got %d.\n", + TVM_CRT_MAX_ARGS, args_count); + status = -1; + break; + } +#if TVM_CRT_DEBUG + printf("creating tvm_op: %s with node_id=%d\n", inode->param.func_name, nid); +#endif // TVM_CRT_DEBUG + TVMPackedFunc pf; + runtime->CreateTVMOp(runtime, &(inode->param), args, args_count, inode->inputs_count, &pf); + runtime->op_execs[nid] = pf; + } + } + return status; +} + +typedef struct 
TVMOpArgs { + DLTensor args[TVM_CRT_MAX_ARGS]; + uint32_t args_count; + TVMValue arg_values[TVM_CRT_MAX_ARGS]; + uint32_t arg_values_count; + uint32_t arg_tcodes[TVM_CRT_MAX_ARGS]; + uint32_t arg_tcodes_count; + int64_t shape_data[TVM_CRT_MAX_ARGS]; + uint32_t shape_data_count; +} TVMOpArgs; + +int32_t TVMGraphRuntime_CreateTVMOp(TVMGraphRuntime * runtime, const TVMOpParam * param, + DLTensorPtr * args, const uint32_t args_count, + uint32_t num_inputs, TVMPackedFunc * pf) { + uint32_t idx; + TVMOpArgs arg_ptr; + memset(&arg_ptr, 0, sizeof(TVMOpArgs)); + arg_ptr.args_count = args_count; + if (param->flatten_data) { + arg_ptr.shape_data_count = arg_ptr.args_count; + } + for (idx = 0; idx < arg_ptr.args_count; ++idx) { + TVMValue v; + memset(&v, 0, sizeof(v)); + DLTensor * t = &(arg_ptr.args[idx]); + /* v.v_handle = &((*args)[idx]); */ + v.v_handle = args[idx]; + arg_ptr.arg_values[idx] = v; + arg_ptr.arg_values_count++; + arg_ptr.arg_tcodes[idx] = kTVMNDArrayHandle; + arg_ptr.arg_tcodes_count++; + if (param->flatten_data) { + arg_ptr.shape_data[idx] = Shape_Accumulate(t->shape, t->ndim); + t->ndim = 1; + t->shape[0] = arg_ptr.shape_data[idx]; + } + } + if (!strcmp(param->func_name, "__nop") || !strcmp(param->func_name, "__copy")) { + fprintf(stderr, "%s function is not yet supported.", param->func_name); + } + + runtime->module.GetFunction(param->func_name, pf); + TVMArgs targs = TVMArgs_Create(arg_ptr.arg_values, arg_ptr.arg_tcodes, arg_ptr.arg_values_count); + pf->SetArgs(pf, &targs); + + return 0; +} + +/*! + * \brief Initialize the graph executor with graph and context. + * \param graph_json The execution graph. + * \param module The module containing the compiled functions for the host + * processor. + * \param ctxs The context of the host and devices where graph nodes will be + * executed on. 
+ */ +void TVMGraphRuntime_Init(TVMGraphRuntime * runtime, const char * graph_json, + const TVMModule * module, const TVMContext * ctxs) { + JSONReader reader = JSONReader_Create(graph_json); + runtime->Load(runtime, &reader); + runtime->ctxs[0] = ctxs[0]; + runtime->SetupStorage(runtime); + runtime->SetupOpExecs(runtime); + JSONReader_Release(&reader); +} + +TVMGraphRuntime * TVMGraphRuntimeCreate(const char * sym_json, + const TVMModule * m, const TVMContext * ctxs) { + TVMGraphRuntime * runtime = (TVMGraphRuntime*)malloc(sizeof(TVMGraphRuntime)); // NOLINT(*) + memset(runtime, 0, sizeof(TVMGraphRuntime)); + runtime->GetEntryId = TVMGraphRuntime_GetEntryId; + runtime->GetInputIndex = TVMGraphRuntime_GetInputIndex; + runtime->Init = TVMGraphRuntime_Init; + runtime->Load = TVMGraphRuntime_Load; + runtime->SetInput = TVMGraphRuntime_SetInput; + runtime->LoadParams = TVMGraphRuntime_LoadParams; + runtime->Run = TVMGraphRuntime_Run; + runtime->GetOutput = TVMGraphRuntime_GetOutput; + runtime->SetupStorage = TVMGraphRuntime_SetupStorage; + runtime->SetupOpExecs = TVMGraphRuntime_SetupOpExecs; + runtime->CreateTVMOp = TVMGraphRuntime_CreateTVMOp; + runtime->module.GetFunction = TVMModule_GetFunction; + // init + runtime->Init(runtime, sym_json, m, ctxs); + return runtime; +} + +void TVMGraphRuntimeRelease(TVMGraphRuntime ** pptr) { + int32_t idx; + TVMGraphRuntime * runtime = *pptr; + for (idx = 0; idx < runtime->storage_pool_count; ++idx) { + TVMNDArray_Release(&(runtime->storage_pool[idx])); + } + free(*pptr); +} diff --git a/src/runtime/crt/graph_runtime.h b/src/runtime/crt/graph_runtime.h new file mode 100644 index 000000000000..7fe395c5b09c --- /dev/null +++ b/src/runtime/crt/graph_runtime.h @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file graph_runtime.h + * \brief Tiny graph runtime that can run graph containing only tvm PackedFunc. + */ +#ifndef TVM_RUNTIME_CRT_GRAPH_RUNTIME_H_ +#define TVM_RUNTIME_CRT_GRAPH_RUNTIME_H_ + +#include + +#include "load_json.h" +#include "ndarray.h" +#include "packed_func.h" +#include "module.h" + +/*! \brief operator attributes about tvm op */ +typedef struct TVMOpParam { + char func_name[120]; + uint32_t num_inputs; + uint32_t num_outputs; + uint32_t flatten_data; +} TVMOpParam; + +// Memory pool entry. 
typedef struct TVMGraphRuntimePoolEntry {
  // largest byte size requested for this storage id
  size_t size;
  // device type the entry is allocated on
  int device_type;
} TVMGraphRuntimePoolEntry;

// Node entry: one output of one node, as written in the graph JSON.
typedef struct TVMGraphRuntimeNodeEntry {
  uint32_t node_id;
  uint32_t index;
  uint32_t version;
  // JSON Loader
  void (*Load)(JSONReader *reader);
} TVMGraphRuntimeNodeEntry;

// Node
typedef struct TVMGraphRuntimeNode {
  // operator type in string
  char op_type[16];
  // name of the op
  char name[120];
  // parameters
  TVMOpParam param;
  // inputs
  TVMGraphRuntimeNodeEntry inputs[GRAPH_RUNTIME_NODE_MAX_INPUTS];
  size_t inputs_count;
  // control deps
  uint32_t control_deps[200];
  // JSON Loader for the node's attribute object
  void (*LoadAttrs)(struct TVMGraphRuntimeNode * node, JSONReader *reader, TVMOpParam* param);
  // JSON Loader for the whole node; returns 0 on success
  int (*Load)(struct TVMGraphRuntimeNode * node, JSONReader *reader);
} TVMGraphRuntimeNode;

// Graph attribute
typedef struct TVMGraphRuntimeGraphAttr {
  uint32_t storage_num_not_alloctaed;
  uint32_t storage_id[GRAPH_RUNTIME_MAX_NODES];
  uint32_t device_index[GRAPH_RUNTIME_MAX_NODES];
  char     dltype[GRAPH_RUNTIME_MAX_NODES][10];  // "int8", "int16", "float32"
  uint32_t dltype_count;
  int64_t  shape[GRAPH_RUNTIME_MAX_NODES][TVM_CRT_MAX_NDIM];
  uint32_t ndim[GRAPH_RUNTIME_MAX_NODES];
  uint32_t shape_count;
} TVMGraphRuntimeGraphAttr;

typedef DLTensor* DLTensorPtr;

/*!
 * \brief Tiny graph runtime.
 *
 *  This runtime can be accessed from various languages via the
 *  TVM runtime PackedFunc API.
 */
/* class GraphRuntime : public ModuleNode { */
typedef struct TVMGraphRuntime {
  /*! \brief Run all the operations one by one. */
  void (*Run)(struct TVMGraphRuntime * runtime);

  /*!
   * \brief Initialize the graph executor with graph and context.
   * \param graph_json The execution graph.
   * \param module The module containing the compiled functions for the host
   *  processor.
   * \param ctxs The context of the host and devices where graph nodes will be
   *  executed on.
   */
  void (*Init)(struct TVMGraphRuntime * runtime,
               const char * graph_json,
               const TVMModule * module,
               const TVMContext * ctxs);

  /*!
   * \brief Get the input index given the name of input.
   * \param name The name of the input.
   * \return The index of input.
   */
  int (*GetInputIndex)(struct TVMGraphRuntime * runtime, const char * name);

  /*!
   * \brief Set the input with the given name.
   * \param name The name of the input.
   * \param data_in The input data.
   */
  void (*SetInput)(struct TVMGraphRuntime * runtime, const char * name, DLTensor* data_in);
  /*!
   * \brief Copy the output with the given index into a caller-provided tensor.
   * \param index The output index.
   * \param out The tensor that receives the output data.
   *
   * \return A status code.
   */
  int (*GetOutput)(struct TVMGraphRuntime * runtime, const int32_t index, DLTensor * out);
  /*!
   * \brief Load parameters from parameter blob.
   * \param param_blob A binary blob of parameter.
   * \param param_size The size of the blob.
   */
  int (*LoadParams)(struct TVMGraphRuntime * runtime, const char * param_blob,
                    const uint32_t param_size);

  /*! \brief Load the graph (nodes, inputs, outputs, attributes) from JSON. */
  int (*Load)(struct TVMGraphRuntime * runtime, JSONReader *reader);
  /*! \brief Setup the temporal storage */
  void (*SetupStorage)(struct TVMGraphRuntime * runtime);
  /*! \brief Setup the executors. */
  int (*SetupOpExecs)(struct TVMGraphRuntime * runtime);

  /*!
   * \brief Create an execution function given input.
   * \param attrs The node attributes.
   * \param args The arguments to the functor, including inputs and outputs.
   * \param args_count The number of entries in args.
   * \param num_inputs Number of inputs.
   * \param pf The created executor.
   * \return A status code.
   */
  int32_t (*CreateTVMOp)(struct TVMGraphRuntime * runtime, const TVMOpParam * attrs,
                         DLTensorPtr * args, const uint32_t args_count,
                         uint32_t num_inputs, TVMPackedFunc * pf);

  // Get node entry index.
  uint32_t (*GetEntryId)(struct TVMGraphRuntime * runtime, uint32_t nid, uint32_t index);

  /*! \brief The graph nodes. */
  /* GraphRuntimeNode nodes_[GRAPH_RUNTIME_MAX_NODES]; */
  TVMGraphRuntimeNode nodes[GRAPH_RUNTIME_MAX_NODES];
  uint32_t           nodes_count;
  /*! \brief The argument nodes. */
  uint32_t input_nodes[GRAPH_RUNTIME_MAX_INPUT_NODES];
  uint32_t   input_nodes_count;
  /*! \brief Used for quick entry indexing. */
  uint32_t node_row_ptr[GRAPH_RUNTIME_MAX_NODE_ROW_PTR];
  uint32_t node_row_ptr_count;
  /*! \brief Output entries. */
  TVMGraphRuntimeNodeEntry outputs[GRAPH_RUNTIME_MAX_OUTPUTS];
  uint32_t              outputs_count;
  /*! \brief Additional graph attributes. */
  TVMGraphRuntimeGraphAttr attrs;
  /*! \brief The code module that contains both host and device code. */
  TVMModule module;
  /*! \brief Execution context of all devices including the host. */
  TVMContext ctxs[GRAPH_RUNTIME_MAX_CONTEXTS];
  uint32_t   ctxs_count;
  /*! \brief Common storage pool for all devices. */
  TVMNDArray  storage_pool[GRAPH_RUNTIME_MAX_NODES];
  uint32_t storage_pool_count;
  /*! \brief Data entry of each node. */
  TVMNDArray  data_entry[GRAPH_RUNTIME_MAX_NODES];
  uint32_t data_entry_count;
  /*! \brief Operator on each node. */
  TVMPackedFunc op_execs[GRAPH_RUNTIME_MAX_NODES];
  uint32_t op_execs_count;
} TVMGraphRuntime;

// public functions
TVMGraphRuntime * TVMGraphRuntimeCreate(const char * sym_json, const TVMModule * m,
                                        const TVMContext * ctxs);
void TVMGraphRuntimeRelease(TVMGraphRuntime ** runtime);

// private functions
void TVMGraphRuntime_SetInput(TVMGraphRuntime * runtime, const char * name, DLTensor* data_in);
int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blob,
                               const uint32_t param_size);
void TVMGraphRuntime_Run(TVMGraphRuntime * runtime);
int TVMGraphRuntime_GetOutput(TVMGraphRuntime * runtime, const int32_t idx, DLTensor * out);

#endif  // TVM_RUNTIME_CRT_GRAPH_RUNTIME_H_
diff --git a/src/runtime/crt/load_json.c b/src/runtime/crt/load_json.c
new file mode 100644
index 000000000000..894ab8938a10
--- /dev/null
+++ b/src/runtime/crt/load_json.c
@@ -0,0 +1,359 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file load_json.c
 * \brief Load graph from JSON file.
+ */ +#include "load_json.h" + +// the node entry structure in serialized format +typedef struct JSONNodeEntry { + uint32_t node_id; + uint32_t index; + uint32_t version; + void (*Load)(struct JSONNodeEntry * entry, JSONReader *reader); +} JSONNodeEntry; + +void JSONNodeEntryLoad(JSONNodeEntry * entry, JSONReader *reader) { + reader->BeginArray(reader); + if (reader->NextArrayItem(reader)) { fprintf(stderr, "invalid json format\n"); } + reader->ReadUnsignedInteger(reader, &(entry->node_id)); + if (reader->NextArrayItem(reader)) { fprintf(stderr, "invalid json format\n"); } + reader->ReadUnsignedInteger(reader, &(entry->index)); + if (reader->NextArrayItem(reader)) { + reader->ReadUnsignedInteger(reader, &(entry->version)); + if (!reader->NextArrayItem(reader)) { fprintf(stderr, "invalid json format\n"); } + } else { + entry->version = 0; + } +} + + +// implementation of Seq class + +void SeqPush(Seq * seq, uint32_t src) { + if (seq->size >= seq->allocated) { + printf("seq too large.\n"); + } + seq->data[seq->size] = src; + seq->size += 1; +} + +uint32_t * SeqBack(Seq * seq) { + if (seq->size >= seq->allocated) { + printf("seq too large.\n"); + } + return seq->data + (seq->size-1); +} + +void SeqPop(Seq * seq) { + if (seq->size >= seq->allocated) { + printf("seq size is too large.\n"); + } + if (seq->size == 0) { + printf("seq size is too small.\n"); + } + seq->size -= 1; +} + +Seq * SeqCreate(uint64_t len) { + Seq * seq = (Seq*)malloc(sizeof(Seq)); // NOLINT(*) + memset(seq, 0, sizeof(Seq)); + seq->allocated = len; + seq->data = (uint32_t*)malloc(sizeof(uint32_t)*len); // NOLINT(*) + seq->push_back = SeqPush; + seq->back = SeqBack; + seq->pop_back = SeqPop; + return seq; +} + +void SeqRelease(Seq ** seq) { + free((*seq)->data); + free(*seq); +} + + +// implementations of JSONReader + +/*! + * \brief Takes the next char from the input source. + * \return the next character. 
+ */ +char JSONReader_NextChar(JSONReader * reader) { + char ch = reader->isptr[0]; + reader->isptr += 1; + return ch; +} + +/*! + * \brief Returns the next char from the input source. + * \return the next character. + */ +char JSONReader_PeekNextChar(JSONReader * reader) { + return reader->isptr[0]; +} + +/*! + * \brief Read next nonspace character. + * \return the next nonspace character. + */ +char JSONReader_NextNonSpace(JSONReader * reader) { + int ch; + do { + ch = reader->NextChar(reader); + if (ch == '\n') { ++(reader->line_count_n_); } + if (ch == '\r') { ++(reader->line_count_r_); } + } while (isspace(ch)); + return ch; +} + +/*! + * \brief Read just before next nonspace but not read that. + * \return the next nonspace character. + */ +char JSONReader_PeekNextNonSpace(JSONReader * reader) { + int ch; + while (1) { + ch = reader->PeekNextChar(reader); + if (ch == '\n') { ++(reader->line_count_n_); } + if (ch == '\r') { ++(reader->line_count_r_); } + if (!isspace(ch)) break; + reader->NextChar(reader); + } + return ch; +} + +/*! + * \brief Parse next JSON string. + * \param out_str the output string. 
+ * \throw dmlc::Error when next token is not string + */ +int JSONReader_ReadString(JSONReader * reader, char * out_str) { + int status = 0; + char ch = reader->NextNonSpace(reader); + char output[128]; + uint32_t output_counter = 0; + memset(output, 0, 128); + while (1) { + ch = reader->NextChar(reader); + if (ch == '\\') { + char sch = reader->NextChar(reader); + switch (sch) { + case 'r': snprintf(output, sizeof(output), "%s\r", output); break; + case 'n': snprintf(output, sizeof(output), "%s\n", output); break; + case '\\': snprintf(output, sizeof(output), "%s\\", output); break; + case 't': snprintf(output, sizeof(output), "%s\t", output); break; + case '\"': snprintf(output, sizeof(output), "%s\"", output); break; + default: fprintf(stderr, "unknown string escape %c\n", sch); + } + } else { + if (ch == '\"') { break; } + if (strlen(output) >= 127) { + fprintf(stderr, "Error: detected buffer overflow.\n"); + status = -1; + break; + } + strncat(output, &ch, 1); + output_counter++; + if (output_counter >= 127) { + fprintf(stderr, "Error: string size greater than 128.\n"); + status = -1; + break; + } + } + if (ch == EOF || ch == '\r' || ch == '\n') { + fprintf(stderr, "Error at line X, Expect \'\"\' but reach end of line\n"); + } + } + snprintf(out_str, sizeof(output), "%s", output); + return status; +} + +int JSONReader_ReadUnsignedInteger(JSONReader * reader, unsigned int * out_value) { + int status = 0; + char* endptr; + const char* icstr = reader->isptr; + unsigned int number = strtol(icstr, &endptr, 10); + reader->isptr += endptr - icstr; + *out_value = number; + return status; +} + + +int JSONReader_ReadInteger(JSONReader * reader, int64_t * out_value) { + int status = 0; + char* endptr; + const char* icstr = reader->isptr; + int64_t number = strtol(icstr, &endptr, 10); + reader->isptr += endptr - icstr; + *out_value = number; + return status; +} + +/*! + * \brief Begin parsing an object. 
+ * \code
+ *  char key[128];
+ *  reader->BeginObject(reader);
+ *  while (reader->NextObjectItem(reader, key)) {
+ *    // read the value that belongs to key
+ *  }
+ * \endcode
+ */
+void JSONReader_BeginObject(JSONReader * reader) {
+  int ch = reader->NextNonSpace(reader);
+  if (!(ch == '{')) {
+    fprintf(stderr, "Error at line X, Expect \'{\' but got \'%c\'\n", ch);
+  }
+  // Open a new scope with zero members seen so far.
+  Seq * scope_counter_ = reader->scope_counter_;
+  scope_counter_->push_back(scope_counter_, 0);
+}
+
+/*!
+ * \brief Try to move to next object item.
+ *  If this call is successful, user can proceed to read in the value.
+ * \param reader the reader to advance.
+ * \param out_key the key to the next object.
+ * \return 1 if the read is successful, 0 if we are at end of the object
+ *         (or at end of the input, which closes the scope as well).
+ */
+uint8_t JSONReader_NextObjectItem(JSONReader * reader, char * out_key) {
+  uint8_t next = 1;
+  Seq * scope_counter_ = reader->scope_counter_;
+  if (scope_counter_->back(scope_counter_)[0] != 0) {
+    // Not the first member: a separator or the closing brace must follow.
+    int ch = reader->NextNonSpace(reader);
+    if (ch == EOF || ch == '\0') {
+      // The input is a NUL terminated string, so running out of input shows
+      // up as '\0'; EOF is never produced by the character readers.
+      next = 0;
+    } else if (ch == '}') {
+      next = 0;
+    } else {
+      if (ch != ',') {
+        fprintf(stderr, "Error at line X, JSON object expect \'}\' or \',\' but got \'%c\'\n", ch);
+      }
+    }
+  } else {
+    // First member: only look ahead, the key's opening quote stays unread.
+    int ch = reader->PeekNextNonSpace(reader);
+    if (ch == '}') {
+      reader->NextChar(reader);
+      next = 0;
+    } else if (ch == '\0') {
+      // Truncated input: close the scope instead of parsing a key out of
+      // whatever follows the terminator.
+      next = 0;
+    }
+  }
+  if (!next) {
+    scope_counter_->pop_back(scope_counter_);
+    return 0;
+  } else {
+    scope_counter_->back(scope_counter_)[0] += 1;
+    reader->ReadString(reader, out_key);
+    int ch = reader->NextNonSpace(reader);
+    if (ch != ':') {
+      fprintf(stderr, "Error at line X, Expect \':\' but get \'%c\'\n", ch);
+    }
+    return 1;
+  }
+}
+
+/*!
+ * \brief Begin parsing an array.
+ * \code
+ *  // value can be any type that is json serializable.
+ *  reader->BeginArray(reader);
+ *  while (reader->NextArrayItem(reader)) {
+ *    // read one array element
+ *  }
+ * \endcode
+ */
+void JSONReader_BeginArray(JSONReader * reader) {
+  int ch = reader->NextNonSpace(reader);
+  if (ch != '[') {
+    fprintf(stderr, "Error at line X, Expect \'[\' but get \'%c\'\n", ch);
+  }
+  // Open a new scope with zero elements seen so far.
+  Seq * scope_counter_ = reader->scope_counter_;
+  scope_counter_->push_back(scope_counter_, 0);
+}
+
+/*!
+ * \brief Try to read the next element in the array.
+ *  If this call is successful, user can proceed to read in the value.
+ * \param reader the reader to advance.
+ * \return 1 if the read is successful, 0 if we are at end of the array
+ *         (or at end of the input, which closes the scope as well).
+ */
+uint8_t JSONReader_NextArrayItem(JSONReader * reader) {
+  uint8_t next = 1;
+  Seq * scope_counter_ = reader->scope_counter_;
+  if (scope_counter_->back(scope_counter_)[0] != 0) {
+    // Not the first element: a separator or the closing bracket must follow.
+    int ch = reader->NextNonSpace(reader);
+    if (ch == EOF || ch == '\0') {
+      // The input is a NUL terminated string, so running out of input shows
+      // up as '\0'; EOF is never produced by the character readers.
+      next = 0;
+    } else if (ch == ']') {
+      next = 0;
+    } else {
+      if (ch != ',') {
+        fprintf(stderr, "Error at line X, JSON array expect \']\' or \',\' but got \'%c\'\n", ch);
+      }
+    }
+  } else {
+    // First element: only look ahead, the element itself stays unread.
+    int ch = reader->PeekNextNonSpace(reader);
+    if (ch == ']') {
+      reader->NextChar(reader);
+      next = 0;
+    } else if (ch == '\0') {
+      // Truncated input: close the scope instead of reporting an element.
+      next = 0;
+    }
+  }
+  if (!next) {
+    scope_counter_->pop_back(scope_counter_);
+    return 0;
+  } else {
+    scope_counter_->back(scope_counter_)[0] += 1;
+    return 1;
+  }
+}
+
+/*!
+ * \brief Constructor.
+ * \param is the input source.
+ *        The text is copied, so the caller keeps ownership of it.  A NULL
+ *        pointer is treated as an empty document.
+ * \return a reader positioned at the start of its private copy; release it
+ *         with JSONReader_Release.
+ */
+JSONReader JSONReader_Create(const char * is) {
+  // Stand-in input used when no private copy could be allocated; it is only
+  // ever read, never written.
+  static char empty_input[1] = {'\0'};
+  const char * source = (is != NULL) ? is : "";
+  size_t nbytes = strlen(source) + 1;  // includes the terminating NUL
+  JSONReader reader;
+  memset(&reader, 0, sizeof(JSONReader));
+  reader.scope_counter_ = SeqCreate(200);
+  reader.NextChar = JSONReader_NextChar;
+  reader.PeekNextChar = JSONReader_PeekNextChar;
+  reader.NextNonSpace = JSONReader_NextNonSpace;
+  reader.PeekNextNonSpace = JSONReader_PeekNextNonSpace;
+  reader.ReadString = JSONReader_ReadString;
+  reader.ReadUnsignedInteger = JSONReader_ReadUnsignedInteger;
+  reader.ReadInteger = JSONReader_ReadInteger;
+  reader.BeginArray = JSONReader_BeginArray;
+  reader.BeginObject = JSONReader_BeginObject;
+  reader.NextArrayItem = JSONReader_NextArrayItem;
+  reader.NextObjectItem = JSONReader_NextObjectItem;
+  reader.is_ = (char*)malloc(nbytes);  // NOLINT(*)
+  if (reader.is_ != NULL) {
+    memcpy(reader.is_, source, nbytes);
+    reader.isptr = reader.is_;
+  } else {
+    // Out of memory: report it and leave the reader on an empty document
+    // rather than writing through a NULL pointer.
+    fprintf(stderr, "Error: failed to allocate %lu bytes for JSON input\n",
+            (unsigned long)nbytes);  // NOLINT(*)
+    reader.isptr = empty_input;
+  }
+  return reader;
+}
+
+/*!
+ * \brief Release the resources owned by a reader.
+ * \param reader the reader created by JSONReader_Create.
+ */
+void JSONReader_Release(JSONReader * reader) {
+  SeqRelease(&(reader->scope_counter_));
+  free(reader->is_);
+  // Clear the pointers so a second release does not free the copy twice.
+  reader->is_ = NULL;
+  reader->isptr = NULL;
+}
diff --git a/src/runtime/crt/load_json.h b/src/runtime/crt/load_json.h
new file mode 100644
index 000000000000..a5df7a055af0
--- /dev/null
+++ b/src/runtime/crt/load_json.h
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
+ * See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file load_json.h
+ * \brief Lightweight JSON reader that parses JSON text into C data structs.
+ */
+#ifndef TVM_RUNTIME_CRT_LOAD_JSON_H_
+#define TVM_RUNTIME_CRT_LOAD_JSON_H_
+
+// NOTE(review): the header names of the include directives were lost in this
+// copy of the patch.  The list below is what the declarations in this file
+// (size_t, fixed-width integers) and the stdio/ctype calls of the reader
+// need - confirm against the original commit.
+#include <ctype.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+
+/*! \brief Tags naming the kind of value a JSON read is expected to produce. */
+enum {
+  JSON_READ_TYPE_U8 = 1,
+  JSON_READ_TYPE_S8 = 2,
+  JSON_READ_TYPE_U16 = 3,
+  JSON_READ_TYPE_S16 = 4,
+  JSON_READ_TYPE_U32 = 5,
+  JSON_READ_TYPE_S32 = 6,
+  JSON_READ_TYPE_F32 = 7,
+  JSON_READ_TYPE_F64 = 8,
+  JSON_READ_TYPE_GRAPH_RUNTIME_NODE = 9,
+  JSON_READ_TYPE_GRAPH_RUNTIME_NODE_ENTRY = 10,
+  JSON_READ_TYPE_GRAPH_RUNTIME_GRAPH_ATTR = 11
+};
+
+/*!
+ * \brief Sequence of uint32_t values with stack style access.
+ *  The reader uses it to count the items seen in each open scope.
+ */
+typedef struct Seq {
+  uint32_t * data;
+  uint64_t allocated;
+  uint32_t size;
+  void (*push_back)(struct Seq * seq, uint32_t src);
+  uint32_t * (*back)(struct Seq * seq);
+  void (*pop_back)(struct Seq * seq);
+} Seq;
+
+/*!
+ * \brief Lightweight JSON reader over an in-memory string.
+ *  The caller drives parsing through the Begin / Next / Read callbacks, so it
+ *  needs to know the schema of the document being read.
+ */
+typedef struct JSONReader {
+  /*! \brief internal reader string */
+  char * is_;
+  /*! \brief read cursor into the reader string */
+  char * isptr;
+  /*! \brief "\\r" counter */
+  size_t line_count_r_;
+  /*! \brief "\\n" counter */
+  size_t line_count_n_;
+  /*!
+   * \brief record how many element processed in
+   *  current array/object scope.
+   */
+  Seq * scope_counter_;
+
+  char (*NextChar)(struct JSONReader * reader);
+  char (*NextNonSpace)(struct JSONReader * reader);
+  char (*PeekNextChar)(struct JSONReader * reader);
+  char (*PeekNextNonSpace)(struct JSONReader * reader);
+  int (*ReadUnsignedInteger)(struct JSONReader * reader, unsigned int * out_value);
+  int (*ReadInteger)(struct JSONReader * reader, int64_t * out_value);
+  int (*ReadString)(struct JSONReader * reader, char * out_value);
+  void (*BeginArray)(struct JSONReader * reader);
+  void (*BeginObject)(struct JSONReader * reader);
+  uint8_t (*NextObjectItem)(struct JSONReader * reader, char * out_key);
+  uint8_t (*NextArrayItem)(struct JSONReader * reader);
+} JSONReader;
+
+/*!
+ * \brief Constructor of JSONReader class
+ * \param is the input source.
+ */
+JSONReader JSONReader_Create(const char * is);
+
+/*!
+ * \brief Release the resources held by a reader.
+ * \param reader the reader to release.
+ */
+void JSONReader_Release(JSONReader * reader);
+
+#endif  // TVM_RUNTIME_CRT_LOAD_JSON_H_
diff --git a/src/runtime/crt/module.h b/src/runtime/crt/module.h
new file mode 100644
index 000000000000..8ff979b872e6
--- /dev/null
+++ b/src/runtime/crt/module.h
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/crt/module.h
+ * \brief Runtime container of the functions
+ */
+#ifndef TVM_RUNTIME_CRT_MODULE_H_
+#define TVM_RUNTIME_CRT_MODULE_H_
+
+// NOTE(review): the header names of the two include directives were lost in
+// this copy of the patch; these are the presumed originals - confirm against
+// the original commit.
+#include <string.h>
+#include <tvm/runtime/c_runtime_api.h>
+
+struct TVMPackedFunc;
+typedef struct TVMPackedFunc TVMPackedFunc;
+
+/*!
+ * \brief Module container of TVM.
+ */
+typedef struct TVMModule {
+  /*!
+   * \brief Get packed function from current module by name.
+   *
+   * \param name The name of the function.
+   * \param pf The result function.
+   *
+   * What pf holds when the function does not exist is decided by the
+   * implementation assigned to this pointer.
+   */
+  void (*GetFunction)(const char * name, TVMPackedFunc * pf);
+} TVMModule;
+
+#endif  // TVM_RUNTIME_CRT_MODULE_H_
diff --git a/src/runtime/crt/ndarray.c b/src/runtime/crt/ndarray.c
new file mode 100644
index 000000000000..016fdd5add95
--- /dev/null
+++ b/src/runtime/crt/ndarray.c
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ndarray.c
+ * \brief NDArray container infrastructure.
+ */
+
+#include "ndarray.h"
+
+/*!
+ * \brief Copy nbytes from the stream into dst and advance the stream.
+ *  memcpy is used because the stream is a byte buffer with no alignment
+ *  guarantee for the wider types stored in it.
+ */
+static void TVMNDArray_ReadBytes(void * dst, const char ** strm, size_t nbytes) {
+  memcpy(dst, *strm, nbytes);
+  *strm += nbytes;
+}
+
+/*!
+ * \brief Create an array descriptor with its own copy of the shape.
+ * \param ndim number of dimensions.
+ * \param shape the ndim extents; copied into newly allocated memory.
+ * \param dtype element type.
+ * \param ctx device context recorded in the tensor.
+ * \return the descriptor; its data pointer is left at 0.
+ */
+TVMNDArray TVMNDArray_Create(uint32_t ndim, const tvm_index_t * shape,
+                             DLDataType dtype, DLContext ctx) {
+  TVMNDArray ret;
+  memset(&ret, 0, sizeof(TVMNDArray));
+  ret.dl_tensor.ndim = ndim;
+  if (ndim > 0) {
+    ret.dl_tensor.shape = (int64_t*)malloc(sizeof(int64_t)*ndim);  // NOLINT(*)
+    if (ret.dl_tensor.shape != NULL) {
+      memcpy(ret.dl_tensor.shape, shape, sizeof(int64_t)*ndim);
+    } else {
+      // Never leave a descriptor whose ndim promises a shape that is absent.
+      fprintf(stderr, "Error: failed to allocate shape for %u dimensions\n",
+              (unsigned int)ndim);
+      ret.dl_tensor.ndim = 0;
+    }
+  }
+  ret.dl_tensor.dtype = dtype;
+  ret.dl_tensor.ctx = ctx;
+  ret.dl_tensor.data = 0;
+  return ret;
+}
+
+/*!
+ * \brief Create an array and allocate zero filled storage for it.
+ * \param ndim number of dimensions.
+ * \param shape the ndim extents.
+ * \param dtype element type; its bit width is rounded up to whole bytes.
+ * \param ctx device context recorded in the tensor.
+ * \return the array; its data pointer is 0 when the allocation failed.
+ */
+TVMNDArray TVMNDArray_Empty(uint32_t ndim, const tvm_index_t * shape,
+                            DLDataType dtype, DLContext ctx) {
+  TVMNDArray ret = TVMNDArray_Create(ndim, shape, dtype, ctx);
+  int64_t num_elems = 1;
+  int elem_bytes = (dtype.bits + 7) / 8;
+  int64_t total_bytes;
+  uint32_t idx;
+  for (idx = 0; idx < ret.dl_tensor.ndim; ++idx) {
+    num_elems *= shape[idx];
+  }
+  // The workspace allocator takes a size in bytes.  Passing the element
+  // count made the buffer smaller than the memset below for every element
+  // type wider than one byte.
+  total_bytes = num_elems * elem_bytes;
+  ret.dl_tensor.data = TVMBackendAllocWorkspace(kDLCPU, 0, total_bytes, dtype.code, dtype.bits);
+  if (ret.dl_tensor.data != NULL) {
+    memset(ret.dl_tensor.data, 0, total_bytes);
+  }
+  return ret;
+}
+
+/*!
+ * \brief Load one array from a serialized stream.
+ *
+ * The fields are read in this order: magic, reserved word, context, ndim,
+ * dtype, ndim shape entries, data size in bytes, data.
+ *
+ * \param ret receives the loaded array.
+ * \param strm cursor into the serialized bytes; advanced past what was read.
+ * \return 0 on success, -1 on a format error.  An invalid ndim, a failed
+ *         allocation or a size mismatch stops the load before anything is
+ *         written beyond the bounds of the local shape or the new buffer.
+ */
+int TVMNDArray_Load(TVMNDArray * ret, const char ** strm) {
+  int32_t status = 0;
+  uint64_t header, reserved;
+  DLContext ctx;
+  uint32_t ndim;
+  DLDataType dtype;
+  int64_t shape[TVM_CRT_MAX_NDIM];
+  int64_t data_byte_size;
+  int64_t num_elems = 1;
+  int elem_bytes;
+  uint32_t idx;
+
+  TVMNDArray_ReadBytes(&header, strm, sizeof(header));
+  if (header != kTVMNDArrayMagic) {
+    fprintf(stderr, "Invalid DLTensor file format\n");
+    status = -1;
+  }
+  TVMNDArray_ReadBytes(&reserved, strm, sizeof(reserved));
+  (void)reserved;  // read only to step over it
+  TVMNDArray_ReadBytes(&ctx, strm, sizeof(ctx));
+  TVMNDArray_ReadBytes(&ndim, strm, sizeof(ndim));
+  TVMNDArray_ReadBytes(&dtype, strm, sizeof(dtype));
+  if ((ndim == 0) || (ndim > TVM_CRT_MAX_NDIM)) {
+    fprintf(stderr, "Invalid ndim=%u: expected to be 1 ~ %d.\n",
+            (unsigned int)ndim, TVM_CRT_MAX_NDIM);
+    // Going on would write ndim entries into shape[TVM_CRT_MAX_NDIM].
+    return -1;
+  }
+  if (ctx.device_type != kDLCPU) {
+    fprintf(stderr, "Invalid DLTensor context: can only save as CPU tensor\n");
+    status = -1;
+  }
+  for (idx = 0; idx < ndim; idx++) {
+    TVMNDArray_ReadBytes(&shape[idx], strm, sizeof(shape[idx]));
+  }
+  *ret = TVMNDArray_Empty(ndim, shape, dtype, ctx);
+  elem_bytes = (ret->dl_tensor.dtype.bits + 7) / 8;
+  for (idx = 0; idx < ndim; ++idx) {
+    num_elems *= shape[idx];
+  }
+  TVMNDArray_ReadBytes(&data_byte_size, strm, sizeof(data_byte_size));
+  if ((data_byte_size < 0) || (data_byte_size != num_elems * elem_bytes)) {
+    fprintf(stderr, "invalid DLTensor file format: data_byte_size=%lld, "
+            "while num_elems*elem_bytes=%lld\n",
+            (long long)data_byte_size, (long long)(num_elems * elem_bytes));  // NOLINT(*)
+    // The buffer holds num_elems * elem_bytes bytes; copying any other
+    // amount would overrun it or leave it partly filled.
+    return -1;
+  }
+  if (ret->dl_tensor.data == NULL) {
+    fprintf(stderr, "Error: failed to allocate %lld bytes for DLTensor data\n",
+            (long long)data_byte_size);  // NOLINT(*)
+    return -1;
+  }
+  memcpy(ret->dl_tensor.data, *strm, (size_t)data_byte_size);
+  *strm += data_byte_size;
+
+  return status;
+}
+
+/*!
+ * \brief Create an array that shares the data of arr under another shape.
+ * \param arr the array whose data pointer and context are reused.
+ * \param shape the ndim extents of the view; copied.
+ * \param ndim number of dimensions of the view.
+ * \param dtype element type of the view.
+ * \return the view; it owns its shape but not its data.
+ */
+TVMNDArray TVMNDArray_CreateView(TVMNDArray * arr, const tvm_index_t * shape,
+                                 uint32_t ndim, DLDataType dtype) {
+  TVMNDArray ret = TVMNDArray_Create(ndim, shape, dtype, arr->dl_tensor.ctx);
+  ret.dl_tensor.data = arr->dl_tensor.data;
+  return ret;
+}
+
+/*!
+ * \brief Free the data and shape of an array.
+ * \param arr the array to release.
+ * \return always 0.
+ */
+int TVMNDArray_Release(TVMNDArray * arr) {
+  // NOTE(review): data is obtained from TVMBackendAllocWorkspace in
+  // TVMNDArray_Empty but handed to free() here, and views made by
+  // TVMNDArray_CreateView share the same data pointer - confirm that the
+  // allocator is malloc based and that only one owner is ever released.
+  free(arr->dl_tensor.data);
+  free(arr->dl_tensor.shape);
+  // Cleared so that releasing the same array twice is harmless.
+  arr->dl_tensor.data = 0;
+  arr->dl_tensor.shape = 0;
+  return 0;
+}
diff --git a/src/runtime/crt/ndarray.h b/src/runtime/crt/ndarray.h
new file mode 100644
index 000000000000..dde23ca6cd41
--- /dev/null
+++ b/src/runtime/crt/ndarray.h
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/crt/ndarray.h
+ * \brief Abstract device memory management API
+ */
+#ifndef TVM_RUNTIME_CRT_NDARRAY_H_
+#define TVM_RUNTIME_CRT_NDARRAY_H_
+
+// NOTE(review): the header names of the include directives were lost in this
+// copy of the patch.  The list below provides what this header and ndarray.c
+// use (tvm_index_t, TVMBackendAllocWorkspace, DLTensor, stdio, stdlib and
+// string functions) - confirm against the original commit.
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/c_backend_api.h>
+#include <dlpack/dlpack.h>
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+/*! \brief Magic number for NDArray file */
+static const uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
+
+/*! \brief Magic number for NDArray list file */
+static const uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
+
+/*! \brief NDArray container: a thin wrapper around a DLTensor. */
+typedef struct TVMNDArray {
+  DLTensor dl_tensor;
+} TVMNDArray;
+
+/*! \brief Create an array descriptor for the given shape, type and context. */
+TVMNDArray TVMNDArray_Create(uint32_t ndim, const tvm_index_t * shape,
+                             DLDataType dtype, DLContext ctx);
+
+/*! \brief Create an array and allocate storage for its elements. */
+TVMNDArray TVMNDArray_Empty(uint32_t ndim, const tvm_index_t * shape,
+                            DLDataType dtype, DLContext ctx);
+
+/*!
+ * \brief Load an array from a serialized stream.
+ * \param ret receives the array.
+ * \param strm cursor into the serialized bytes; advanced past what was read.
+ * \return 0 on success, -1 on error.
+ */
+int TVMNDArray_Load(TVMNDArray * ret, const char ** strm);
+
+/*! \brief Create an array that shares the data of arr under another shape. */
+TVMNDArray TVMNDArray_CreateView(TVMNDArray * arr, const tvm_index_t * shape,
+                                 uint32_t ndim, DLDataType dtype);
+
+/*! \brief Free the data and shape held by an array. */
+int TVMNDArray_Release(TVMNDArray * arr);
+
+#endif  // TVM_RUNTIME_CRT_NDARRAY_H_
diff --git a/src/runtime/crt/packed_func.h b/src/runtime/crt/packed_func.h
new file mode 100644
index 000000000000..21370b69c8c0
--- /dev/null
+++ b/src/runtime/crt/packed_func.h
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
+ * The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/packed_func.h
+ * \brief Type-erased function used across TVM API.
+ */
+#ifndef TVM_RUNTIME_CRT_PACKED_FUNC_H_
+#define TVM_RUNTIME_CRT_PACKED_FUNC_H_
+
+// NOTE(review): the header names of the include directives were lost in this
+// copy of the patch.  The list below provides what this header uses (TVMValue
+// and the DLPack types, assert, fprintf, strtoul and the string functions) -
+// confirm against the original commit.
+#include <tvm/runtime/c_runtime_api.h>
+
+#include <assert.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "module.h"
+
+/*!
+ * \brief Convert a type name such as "float32" or "int8x4" to a DLDataType.
+ *
+ * The name is a prefix ("int", "uint", "float", "handle") optionally followed
+ * by a bit width and an "x<lanes>" suffix; "bool" and the empty string are
+ * handled separately.
+ *
+ * \param s the type name.
+ * \return the parsed type.  Malformed names are reported on stderr.
+ */
+static inline DLDataType String2DLDataType(const char * s) {
+  DLDataType t;
+  // handle None type
+  if (strlen(s) == 0) {
+    t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle;
+    return t;
+  }
+  // Defaults for names that carry no width / lanes.  The code gets a
+  // default as well: an unrecognized prefix used to leave it uninitialized.
+  t.code = kDLInt; t.bits = 32; t.lanes = 1;
+  const char* scan;
+  if (!strncmp(s, "int", 3)) {
+    t.code = kDLInt; scan = s + 3;
+  } else if (!strncmp(s, "uint", 4)) {
+    t.code = kDLUInt; scan = s + 4;
+  } else if (!strncmp(s, "float", 5)) {
+    t.code = kDLFloat; scan = s + 5;
+  } else if (!strncmp(s, "handle", 6)) {
+    t.code = kTVMOpaqueHandle;
+    t.bits = 64;  // handle uses 64 bit by default.
+    scan = s + 6;
+  } else if (!strcmp(s, "bool")) {
+    t.code = kDLUInt;
+    t.bits = 1;
+    t.lanes = 1;
+    return t;
+  } else {
+    scan = s;
+    fprintf(stderr, "unknown type %s\n", s);
+  }
+  char* xdelim;
+  uint8_t bits = (uint8_t)(strtoul(scan, &xdelim, 10));
+  if (bits != 0) t.bits = bits;
+  char* endpt = xdelim;
+  if (*xdelim == 'x') {
+    t.lanes = (uint16_t)(strtoul(xdelim + 1, &endpt, 10));
+  }
+  // Anything left over after the width and lanes makes the name invalid.
+  if (!(endpt == s + strlen(s))) {
+    fprintf(stderr, "unknown type %s\n", s);
+  }
+  return t;
+}
+
+/*! \brief Fixed capacity argument pack passed to a packed function. */
+typedef struct TVMArgs {
+  TVMValue values[TVM_CRT_MAX_ARGS];
+  int tcodes[TVM_CRT_MAX_ARGS];  /* Data type should be identical to type_codes in TVMPackedCFunc */
+  uint32_t values_count;
+} TVMArgs;
+
+/*!
+ * \brief Build an argument pack from parallel value / type code arrays.
+ * \param values the argument values.
+ * \param tcodes the type code of each value.
+ * \param values_count number of arguments; anything beyond TVM_CRT_MAX_ARGS
+ *        is reported on stderr and dropped.
+ * \return the argument pack.
+ */
+static inline TVMArgs TVMArgs_Create(TVMValue * values, uint32_t * tcodes, uint32_t values_count) {
+  uint32_t idx;
+  TVMArgs args;
+  memset(&args, 0, sizeof(args));
+  if (values_count > TVM_CRT_MAX_ARGS) {
+    // The arrays in TVMArgs are fixed size; copying more would overrun them.
+    fprintf(stderr, "Error: %u arguments exceed TVM_CRT_MAX_ARGS\n",
+            (unsigned int)values_count);
+    values_count = TVM_CRT_MAX_ARGS;
+  }
+  for (idx = 0; idx < values_count; idx++) {
+    memcpy(args.values + idx, values + idx, sizeof(TVMValue));
+    args.tcodes[idx] = tcodes[idx];
+  }
+  args.values_count = values_count;
+  return args;
+}
+
+/*! \brief Packed function that does nothing and reports success. */
+static inline int TVMNoOperation(TVMValue * args, int * type_codes, int num_args,
+                                 TVMRetValueHandle ret, void * res) {
+  return 0;
+}
+
+/*! \brief A named packed function together with the arguments bound to it. */
+typedef struct TVMPackedFunc {
+  char name[200];
+  TVMPackedCFunc fexec;
+  TVMArgs args;
+  void (*Call)(struct TVMPackedFunc * pf);
+  void (*SetArgs)(struct TVMPackedFunc * pf, const struct TVMArgs * args);
+} TVMPackedFunc;
+
+/*! \brief Invoke pf->fexec on the bound arguments; its return value is ignored. */
+static inline void TVMPackedFunc_Call(TVMPackedFunc * pf) {
+  pf->fexec(pf->args.values, pf->args.tcodes, pf->args.values_count, 0, 0);
+}
+
+/*! \brief Bind a copy of args to pf. */
+static inline void TVMPackedFunc_SetArgs(TVMPackedFunc * pf, const TVMArgs * args) {
+  memcpy(&(pf->args), args, sizeof(TVMArgs));
+}
+
+// NOTE(review): these two are definitions, not extern declarations, so every
+// translation unit that includes this header gets its own - confirm that the
+// header is only ever included from one source file.
+TVMPackedFunc g_fexecs[GRAPH_RUNTIME_MAX_NODES];
+uint32_t g_fexecs_count = 0;
+
+void TVMPackedFunc_SetupExecs();
+
+// Implement TVMModule::GetFunction
+// Put implementation in this file so we have seen the TVMPackedFunc
+/*!
+ * \brief Look up a function by name in g_fexecs.
+ * \param name the function name; must be shorter than sizeof(pf->name).
+ * \param pf receives the function.  When the name is not found a message is
+ *        printed and pf->fexec stays on TVMNoOperation.
+ */
+static inline void TVMModule_GetFunction(const char * name, TVMPackedFunc * pf) {
+  int idx;
+  memset(pf, 0, sizeof(TVMPackedFunc));
+  // Strictly shorter: one byte of pf->name is needed for the terminator.
+  assert(strlen(name) < sizeof(pf->name));
+  // The bound is the size of the destination.  Using strlen(name) as the
+  // bound made snprintf drop the last character of every name.
+  snprintf(pf->name, sizeof(pf->name), "%s", name);
+  pf->Call = TVMPackedFunc_Call;
+  pf->SetArgs = TVMPackedFunc_SetArgs;
+  pf->fexec = &TVMNoOperation;
+  for (idx = 0; idx < GRAPH_RUNTIME_MAX_NODES; idx++) {
+    if (!strcmp(g_fexecs[idx].name, name)) {
+      pf->fexec = g_fexecs[idx].fexec;
+      break;
+    }
+  }
+  if (idx == GRAPH_RUNTIME_MAX_NODES) {
+    fprintf(stderr, "function handle for %s not found\n", name);
+  }
+}
+
+#endif  // TVM_RUNTIME_CRT_PACKED_FUNC_H_
diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh
index 29ffb5f5f92a..5c00fd9c8896 100755
--- a/tests/scripts/task_python_integration.sh
+++ b/tests/scripts/task_python_integration.sh
@@ -19,7 +19,7 @@
 set -e
 set -u
 
-export PYTHONPATH=python:topi/python:apps/extension/python
+export PYTHONPATH=`pwd`/python:`pwd`/topi/python:`pwd`/apps/extension/python
 export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}"
 export TVM_BIND_THREADS=0
 export TVM_NUM_THREADS=2
@@ -30,6 +30,12 @@ find . -type f -path "*.pyc" | xargs rm -f
 # Test TVM
 make cython3
 
+# Test MISRA-C runtime
+cd apps/bundle_deploy
+rm -rf build
+make test
+cd ../..
+
 # Test extern package
 cd apps/extension
 rm -rf lib