[TFLite Runtime] Fix bug and re-enable RPC execution test #5436

Merged (6 commits) on May 15, 2020

Changes from all commits
8 changes: 7 additions & 1 deletion src/runtime/contrib/tflite/tflite_runtime.cc
@@ -93,8 +93,12 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) {
 void TFLiteRuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) {
   const char* buffer = tflite_model_bytes.c_str();
   size_t buffer_size = tflite_model_bytes.size();
+  // The buffer used to construct the model must be kept alive for
+  // dependent interpreters to be used.
+  flatBuffersBuffer_ = std::unique_ptr<char[]>(new char[buffer_size]);
+  std::memcpy(flatBuffersBuffer_.get(), buffer, buffer_size);
   std::unique_ptr<tflite::FlatBufferModel> model =
-      tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
+      tflite::FlatBufferModel::BuildFromBuffer(flatBuffersBuffer_.get(), buffer_size);
   tflite::ops::builtin::BuiltinOpResolver resolver;
   // Build interpreter
   TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
@@ -173,5 +177,7 @@ Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, TVMContext ctx
TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = TFLiteRuntimeCreate(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntimeCreate);
} // namespace runtime
} // namespace tvm
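For orientation, here is a minimal sketch of how these registered globals are driven from Python, mirroring the calls exercised in the test file below; the model path is a placeholder:

import tvm
import numpy as np
from tvm.contrib import tflite_runtime

with open("model.tflite", "rb") as model_fin:  # placeholder path
    model_bytes = model_fin.read()

# Reaches the "tvm.tflite_runtime.create" global registered above. Because
# Init() now copies model_bytes into flatBuffersBuffer_, the interpreter no
# longer reads from freed memory once the Python-side string is released.
runtime = tflite_runtime.create(model_bytes, tvm.cpu(0))
runtime.set_input(0, tvm.nd.array(np.ones((2,), dtype=np.float32)))
runtime.invoke()
out = runtime.get_output(0)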
3 changes: 3 additions & 0 deletions src/runtime/contrib/tflite/tflite_runtime.h
@@ -26,6 +26,7 @@
 #define TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_

 #include <dlpack/dlpack.h>
+#include <tensorflow/lite/interpreter.h>
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/packed_func.h>
@@ -93,6 +94,8 @@ class TFLiteRuntime : public ModuleNode {
   */
  NDArray GetOutput(int index) const;

+  // Buffer backing the interpreter's model
+  std::unique_ptr<char[]> flatBuffersBuffer_;
   // TFLite interpreter
   std::unique_ptr<tflite::Interpreter> interpreter_;
   // TVM context
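Note the placement of the new member: flatBuffersBuffer_ is declared before interpreter_, and since C++ destroys members in reverse declaration order, the buffer outlives the interpreter that reads from it during teardown.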
2 changes: 2 additions & 0 deletions src/runtime/module.cc
@@ -129,6 +129,8 @@ bool RuntimeEnabled(const std::string& target) {
f_name = "device_api.opencl";
} else if (target == "mtl" || target == "metal") {
f_name = "device_api.metal";
} else if (target == "tflite") {
f_name = "target.runtime.tflite";
} else if (target == "vulkan") {
f_name = "device_api.vulkan";
} else if (target == "stackvm") {
Expand Down
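With "target.runtime.tflite" now registered in tflite_runtime.cc, this lookup succeeds whenever the TFLite runtime is compiled in. A sketch of the guard pattern the tests below rely on:

import tvm

# RuntimeEnabled("tflite") resolves the "target.runtime.tflite" global
# registered in tflite_runtime.cc; without that registration it is False.
if not tvm.runtime.enabled("tflite"):
    print("skip because tflite runtime is not enabled...")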
202 changes: 120 additions & 82 deletions tests/python/contrib/test_tflite_runtime.py
@@ -14,92 +14,130 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import pytest
+
 import tvm
 from tvm import te
 import numpy as np
 from tvm import rpc
 from tvm.contrib import util, tflite_runtime
-# import tensorflow as tf
-# import tflite_runtime.interpreter as tflite


-def skipped_test_tflite_runtime():
-
-    def create_tflite_model():
-        root = tf.Module()
-        root.const = tf.constant([1., 2.], tf.float32)
-        root.f = tf.function(lambda x: root.const * x)
-
-        input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32)
-        concrete_func = root.f.get_concrete_function(input_signature)
-        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
-        tflite_model = converter.convert()
-        return tflite_model
-
-
-    def check_local():
-        tflite_fname = "model.tflite"
-        tflite_model = create_tflite_model()
-        temp = util.tempdir()
-        tflite_model_path = temp.relpath(tflite_fname)
-        open(tflite_model_path, 'wb').write(tflite_model)
-
-        # inference via tflite interpreter python apis
-        interpreter = tflite.Interpreter(model_path=tflite_model_path)
-        interpreter.allocate_tensors()
-        input_details = interpreter.get_input_details()
-        output_details = interpreter.get_output_details()
-
-        input_shape = input_details[0]['shape']
-        tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
-        interpreter.set_tensor(input_details[0]['index'], tflite_input)
-        interpreter.invoke()
-        tflite_output = interpreter.get_tensor(output_details[0]['index'])
-
-        # inference via tvm tflite runtime
-        with open(tflite_model_path, 'rb') as model_fin:
-            runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
-            runtime.set_input(0, tvm.nd.array(tflite_input))
-            runtime.invoke()
-            out = runtime.get_output(0)
-            np.testing.assert_equal(out.asnumpy(), tflite_output)
-
-
-    def check_remote():
-        tflite_fname = "model.tflite"
-        tflite_model = create_tflite_model()
-        temp = util.tempdir()
-        tflite_model_path = temp.relpath(tflite_fname)
-        open(tflite_model_path, 'wb').write(tflite_model)
-
-        # inference via tflite interpreter python apis
-        interpreter = tflite.Interpreter(model_path=tflite_model_path)
-        interpreter.allocate_tensors()
-        input_details = interpreter.get_input_details()
-        output_details = interpreter.get_output_details()
-
-        input_shape = input_details[0]['shape']
-        tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
-        interpreter.set_tensor(input_details[0]['index'], tflite_input)
-        interpreter.invoke()
-        tflite_output = interpreter.get_tensor(output_details[0]['index'])
-
-        # inference via remote tvm tflite runtime
-        server = rpc.Server("localhost")
-        remote = rpc.connect(server.host, server.port)
-        ctx = remote.cpu(0)
-        a = remote.upload(tflite_model_path)
-
-        with open(tflite_model_path, 'rb') as model_fin:
-            runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
-            runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
-            runtime.invoke()
-            out = runtime.get_output(0)
-            np.testing.assert_equal(out.asnumpy(), tflite_output)
-
-    check_local()
-    check_remote()


+def _create_tflite_model():
+    if not tvm.runtime.enabled("tflite"):
+        print("skip because tflite runtime is not enabled...")
+        return
+    if not tvm.get_global_func("tvm.tflite_runtime.create", True):
+        print("skip because tflite runtime is not enabled...")
+        return
+
+    try:
+        import tensorflow as tf
+    except ImportError:
+        print('skip because tensorflow not installed...')
+        return
+
+    root = tf.Module()
+    root.const = tf.constant([1., 2.], tf.float32)
+    root.f = tf.function(lambda x: root.const * x)
+
+    input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32)
+    concrete_func = root.f.get_concrete_function(input_signature)
+    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+    tflite_model = converter.convert()
+    return tflite_model
+
+
+@pytest.mark.skip('skip because accessing output tensor is flakey')
+def test_local():
+    if not tvm.runtime.enabled("tflite"):
+        print("skip because tflite runtime is not enabled...")
+        return
+    if not tvm.get_global_func("tvm.tflite_runtime.create", True):
+        print("skip because tflite runtime is not enabled...")
+        return
+
+    try:
+        import tensorflow as tf
+    except ImportError:
+        print('skip because tensorflow not installed...')
+        return
+
+    tflite_fname = "model.tflite"
+    tflite_model = _create_tflite_model()
+    temp = util.tempdir()
+    tflite_model_path = temp.relpath(tflite_fname)
+    open(tflite_model_path, 'wb').write(tflite_model)
+
+    # inference via tflite interpreter python apis
+    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    output_details = interpreter.get_output_details()
+
+    input_shape = input_details[0]['shape']
+    tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
+    interpreter.set_tensor(input_details[0]['index'], tflite_input)
+    interpreter.invoke()
+    tflite_output = interpreter.get_tensor(output_details[0]['index'])
+
+    # inference via tvm tflite runtime
+    with open(tflite_model_path, 'rb') as model_fin:
+        runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
+        runtime.set_input(0, tvm.nd.array(tflite_input))
+        runtime.invoke()
+        out = runtime.get_output(0)
+        np.testing.assert_equal(out.asnumpy(), tflite_output)
+
+
+def test_remote():
+    if not tvm.runtime.enabled("tflite"):
+        print("skip because tflite runtime is not enabled...")
+        return
+    if not tvm.get_global_func("tvm.tflite_runtime.create", True):
+        print("skip because tflite runtime is not enabled...")
+        return
+
+    try:
+        import tensorflow as tf
+    except ImportError:
+        print('skip because tensorflow not installed...')
+        return
+
+    tflite_fname = "model.tflite"
+    tflite_model = _create_tflite_model()
+    temp = util.tempdir()
+    tflite_model_path = temp.relpath(tflite_fname)
+    open(tflite_model_path, 'wb').write(tflite_model)
+
+    # inference via tflite interpreter python apis
+    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
+    interpreter.allocate_tensors()
+    input_details = interpreter.get_input_details()
+    output_details = interpreter.get_output_details()
+
+    input_shape = input_details[0]['shape']
+    tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
+    interpreter.set_tensor(input_details[0]['index'], tflite_input)
+    interpreter.invoke()
+    tflite_output = interpreter.get_tensor(output_details[0]['index'])
+
+    # inference via remote tvm tflite runtime
+    server = rpc.Server("localhost")
+    remote = rpc.connect(server.host, server.port)
+    ctx = remote.cpu(0)
+    a = remote.upload(tflite_model_path)
+
+    with open(tflite_model_path, 'rb') as model_fin:
+        runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
+        runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
+        runtime.invoke()
+        out = runtime.get_output(0)
+        np.testing.assert_equal(out.asnumpy(), tflite_output)
+
+    server.terminate()


 if __name__ == "__main__":
-    # skipped_test_tflite_runtime()
-    pass
+    test_local()
+    test_remote()
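The re-enabled test_remote exercises the full RPC path (server start, connect, model upload, remote create/invoke), while test_local stays disabled via the pytest marker because reading the output tensor is flaky; both can be run through the __main__ block above or collected by pytest.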
3 changes: 3 additions & 0 deletions tests/scripts/task_config_build_cpu.sh
@@ -38,3 +38,6 @@ echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake
 echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake
 echo set\(USE_VTA_TSIM ON\) >> config.cmake
 echo set\(USE_VTA_FSIM ON\) >> config.cmake
+echo set\(USE_TFLITE ON\) >> config.cmake
+echo set\(USE_TENSORFLOW_PATH \"/tensorflow\"\) >> config.cmake
+echo set\(USE_FLATBUFFERS_PATH \"/flatbuffers\"\) >> config.cmake
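For reference, after the escaping is resolved the three appended echo lines leave the following in config.cmake (the /tensorflow and /flatbuffers paths match the layout this CI script assumes):

set(USE_TFLITE ON)
set(USE_TENSORFLOW_PATH "/tensorflow")
set(USE_FLATBUFFERS_PATH "/flatbuffers")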