From 75543d87f6a076af606d51b370686bf817d612d9 Mon Sep 17 00:00:00 2001 From: Michal Piszczek Date: Thu, 14 May 2020 15:40:13 -0700 Subject: [PATCH] Re-enable tflite runtime unit tests with guard --- src/runtime/contrib/tflite/tflite_runtime.cc | 3 + src/runtime/module.cc | 2 + tests/python/contrib/test_tflite_runtime.py | 189 +++++++++++-------- tests/scripts/task_config_build_cpu.sh | 3 + 4 files changed, 115 insertions(+), 82 deletions(-) diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index a40fd04959f8..fa5bb66e9b5d 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -177,5 +177,8 @@ Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, TVMContext ctx TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = TFLiteRuntimeCreate(args[0], args[1]); }); + +TVM_REGISTER_GLOBAL("target.runtime.tflite") +.set_body_typed(TFLiteRuntime); } // namespace runtime } // namespace tvm diff --git a/src/runtime/module.cc b/src/runtime/module.cc index be75ff265ccb..46ef6fab082b 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -129,6 +129,8 @@ bool RuntimeEnabled(const std::string& target) { f_name = "device_api.opencl"; } else if (target == "mtl" || target == "metal") { f_name = "device_api.metal"; + } else if (target == "tflite") { + f_name = "target.runtime.tflite"; } else if (target == "vulkan") { f_name = "device_api.vulkan"; } else if (target == "stackvm") { diff --git a/tests/python/contrib/test_tflite_runtime.py b/tests/python/contrib/test_tflite_runtime.py index 8c883b031a89..91803d9232d2 100644 --- a/tests/python/contrib/test_tflite_runtime.py +++ b/tests/python/contrib/test_tflite_runtime.py @@ -14,92 +14,117 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest + import tvm from tvm import te import numpy as np from tvm import rpc from tvm.contrib import util, tflite_runtime -# import tensorflow as tf -# import tflite_runtime.interpreter as tflite - - -def skipped_test_tflite_runtime(): - - def create_tflite_model(): - root = tf.Module() - root.const = tf.constant([1., 2.], tf.float32) - root.f = tf.function(lambda x: root.const * x) - - input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32) - concrete_func = root.f.get_concrete_function(input_signature) - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - tflite_model = converter.convert() - return tflite_model - - - def check_local(): - tflite_fname = "model.tflite" - tflite_model = create_tflite_model() - temp = util.tempdir() - tflite_model_path = temp.relpath(tflite_fname) - open(tflite_model_path, 'wb').write(tflite_model) - - # inference via tflite interpreter python apis - interpreter = tflite.Interpreter(model_path=tflite_model_path) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() - - input_shape = input_details[0]['shape'] - tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32) - interpreter.set_tensor(input_details[0]['index'], tflite_input) - interpreter.invoke() - tflite_output = interpreter.get_tensor(output_details[0]['index']) - - # inference via tvm tflite runtime - with open(tflite_model_path, 'rb') as model_fin: - runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0)) - runtime.set_input(0, tvm.nd.array(tflite_input)) - runtime.invoke() - out = runtime.get_output(0) - np.testing.assert_equal(out.asnumpy(), tflite_output) - - - def check_remote(): - tflite_fname = "model.tflite" - tflite_model = create_tflite_model() - temp = util.tempdir() - tflite_model_path = temp.relpath(tflite_fname) - open(tflite_model_path, 'wb').write(tflite_model) - - # inference via tflite interpreter python apis - interpreter = tflite.Interpreter(model_path=tflite_model_path) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() - - input_shape = input_details[0]['shape'] - tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32) - interpreter.set_tensor(input_details[0]['index'], tflite_input) - interpreter.invoke() - tflite_output = interpreter.get_tensor(output_details[0]['index']) - - # inference via remote tvm tflite runtime - server = rpc.Server("localhost") - remote = rpc.connect(server.host, server.port) - ctx = remote.cpu(0) - a = remote.upload(tflite_model_path) - - with open(tflite_model_path, 'rb') as model_fin: - runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0)) - runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0))) - runtime.invoke() - out = runtime.get_output(0) - np.testing.assert_equal(out.asnumpy(), tflite_output) - - check_local() - check_remote() + + +def _create_tflite_model(): + root = tf.Module() + root.const = tf.constant([1., 2.], tf.float32) + root.f = tf.function(lambda x: root.const * x) + + input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32) + concrete_func = root.f.get_concrete_function(input_signature) + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + tflite_model = converter.convert() + return tflite_model + + +@pytest.mark.skip('skip because accessing output tensor is flakey') +def test_local(): + if not tvm.runtime.enabled("tflite"): + print("skip because tflite runtime is not enabled...") + return + if not tvm.get_global_func("tvm.tflite_runtime.create", True): + print("skip because tflite runtime is not enabled...") + return + + try: + import tensorflow as tf + except ImportError: + print('skip because tensorflow not installed...') + return + + tflite_fname = "model.tflite" + tflite_model = _create_tflite_model() + temp = util.tempdir() + tflite_model_path = temp.relpath(tflite_fname) + open(tflite_model_path, 'wb').write(tflite_model) + + # inference via tflite interpreter python apis + interpreter = tf.lite.Interpreter(model_path=tflite_model_path) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + input_shape = input_details[0]['shape'] + tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], tflite_input) + interpreter.invoke() + tflite_output = interpreter.get_tensor(output_details[0]['index']) + + # inference via tvm tflite runtime + with open(tflite_model_path, 'rb') as model_fin: + runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0)) + runtime.set_input(0, tvm.nd.array(tflite_input)) + runtime.invoke() + out = runtime.get_output(0) + np.testing.assert_equal(out.asnumpy(), tflite_output) + + +def test_remote(): + if not tvm.runtime.enabled("tflite"): + print("skip because tflite runtime is not enabled...") + return + if not tvm.get_global_func("tvm.tflite_runtime.create", True): + print("skip because tflite runtime is not enabled...") + return + + try: + import tensorflow as tf + except ImportError: + print('skip because tensorflow not installed...') + return + + tflite_fname = "model.tflite" + tflite_model = _create_tflite_model() + temp = util.tempdir() + tflite_model_path = temp.relpath(tflite_fname) + open(tflite_model_path, 'wb').write(tflite_model) + + # inference via tflite interpreter python apis + interpreter = tf.lite.Interpreter(model_path=tflite_model_path) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + input_shape = input_details[0]['shape'] + tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], tflite_input) + interpreter.invoke() + tflite_output = interpreter.get_tensor(output_details[0]['index']) + + # inference via remote tvm tflite runtime + server = rpc.Server("localhost") + remote = rpc.connect(server.host, server.port) + ctx = remote.cpu(0) + a = remote.upload(tflite_model_path) + + with open(tflite_model_path, 'rb') as model_fin: + runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0)) + runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0))) + runtime.invoke() + out = runtime.get_output(0) + np.testing.assert_equal(out.asnumpy(), tflite_output) + + server.terminate() + if __name__ == "__main__": - # skipped_test_tflite_runtime() - pass + test_local() + test_remote() diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh index 9c1cf2870399..ce545bde6609 100755 --- a/tests/scripts/task_config_build_cpu.sh +++ b/tests/scripts/task_config_build_cpu.sh @@ -38,3 +38,6 @@ echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VTA_FSIM ON\) >> config.cmake +echo set\(USE_TFLITE ON\) >> config.cmake +echo set\(USE_TENSORFLOW_PATH \"/tensorflow\"\) >> config.cmake +echo set\(USE_FLATBUFFERS_PATH \"/flatbuffers\"\) >> config.cmake