[TFLite Runtime] Fix bug and re-enable RPC execution test (apache#5436)
michalpiszczek authored and trevor-m committed Jun 18, 2020
1 parent a7207de commit 05029d6
Showing 5 changed files with 135 additions and 83 deletions.
8 changes: 7 additions & 1 deletion src/runtime/contrib/tflite/tflite_runtime.cc
@@ -93,8 +93,12 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) {
void TFLiteRuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) {
const char* buffer = tflite_model_bytes.c_str();
size_t buffer_size = tflite_model_bytes.size();
// The buffer used to construct the model must be kept alive for
// dependent interpreters to be used.
flatBuffersBuffer_ = std::unique_ptr<char[]>(new char[buffer_size]);
std::memcpy(flatBuffersBuffer_.get(), buffer, buffer_size);
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size);
tflite::FlatBufferModel::BuildFromBuffer(flatBuffersBuffer_.get(), buffer_size);
tflite::ops::builtin::BuiltinOpResolver resolver;
// Build interpreter
TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_);
@@ -173,5 +177,7 @@ Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, TVMContext ctx
TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = TFLiteRuntimeCreate(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntimeCreate);
} // namespace runtime
} // namespace tvm
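
The fix above addresses a lifetime bug: tflite::FlatBufferModel::BuildFromBuffer does not copy the caller's buffer, it only keeps a pointer into it, so the model bytes passed to Init must stay alive for as long as the interpreter built from them is used. Copying the bytes into the new flatBuffersBuffer_ member ties their lifetime to the runtime object. Below is a minimal sketch of the same ownership pattern; it assumes the TF Lite headers are on the include path, and the OwnedModel class name is illustrative only, not part of this commit.

#include <cstring>
#include <memory>
#include <string>

#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/model.h>

// Sketch: the object that owns the interpreter also owns a copy of the
// model bytes, so the FlatBuffer data the interpreter reads cannot dangle.
class OwnedModel {
 public:
  bool Init(const std::string& model_bytes) {
    // BuildFromBuffer keeps a pointer into this buffer without copying it,
    // so the copy must live as long as the interpreter does.
    buffer_.reset(new char[model_bytes.size()]);
    std::memcpy(buffer_.get(), model_bytes.data(), model_bytes.size());

    std::unique_ptr<tflite::FlatBufferModel> model =
        tflite::FlatBufferModel::BuildFromBuffer(buffer_.get(), model_bytes.size());
    if (model == nullptr) return false;

    // The FlatBufferModel itself can be destroyed after the interpreter is
    // built, but the underlying buffer_ must remain alive.
    tflite::ops::builtin::BuiltinOpResolver resolver;
    return tflite::InterpreterBuilder(*model, resolver)(&interpreter_) == kTfLiteOk;
  }

 private:
  std::unique_ptr<char[]> buffer_;                    // owned copy of the model bytes
  std::unique_ptr<tflite::Interpreter> interpreter_;  // interpreter built over buffer_
};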
3 changes: 3 additions & 0 deletions src/runtime/contrib/tflite/tflite_runtime.h
@@ -26,6 +26,7 @@
#define TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_

#include <dlpack/dlpack.h>
#include <tensorflow/lite/interpreter.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>

@@ -93,6 +94,8 @@ class TFLiteRuntime : public ModuleNode {
*/
NDArray GetOutput(int index) const;

// Buffer backing the interpreter's model
std::unique_ptr<char[]> flatBuffersBuffer_;
// TFLite interpreter
std::unique_ptr<tflite::Interpreter> interpreter_;
// TVM context
2 changes: 2 additions & 0 deletions src/runtime/module.cc
@@ -129,6 +129,8 @@ bool RuntimeEnabled(const std::string& target) {
f_name = "device_api.opencl";
} else if (target == "mtl" || target == "metal") {
f_name = "device_api.metal";
} else if (target == "tflite") {
f_name = "target.runtime.tflite";
} else if (target == "vulkan") {
f_name = "device_api.vulkan";
} else if (target == "stackvm") {
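
This hunk is what lets the re-enabled test guard itself with tvm.runtime.enabled("tflite"): RuntimeEnabled maps the "tflite" target name to the global function name "target.runtime.tflite", which is registered (first file above) only when the TFLite runtime is compiled in, so the lookup succeeds exactly when the runtime is available. A rough sketch of that check, assuming TVM's Registry API and not a verbatim copy of module.cc:

#include <string>

#include <tvm/runtime/registry.h>

// Rough sketch (assumption, not verbatim module.cc): runtime.enabled("tflite")
// reduces to asking whether the global registered by the TFLite runtime exists.
bool TfliteRuntimeEnabled() {
  const std::string f_name = "target.runtime.tflite";  // name mapped from "tflite"
  // Registry::Get returns nullptr when no global was registered under f_name,
  // i.e. when TVM was built without USE_TFLITE.
  return tvm::runtime::Registry::Get(f_name) != nullptr;
}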
202 changes: 120 additions & 82 deletions tests/python/contrib/test_tflite_runtime.py
@@ -14,92 +14,130 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

import tvm
from tvm import te
import numpy as np
from tvm import rpc
from tvm.contrib import util, tflite_runtime
# import tensorflow as tf
# import tflite_runtime.interpreter as tflite


def skipped_test_tflite_runtime():

def create_tflite_model():
root = tf.Module()
root.const = tf.constant([1., 2.], tf.float32)
root.f = tf.function(lambda x: root.const * x)

input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32)
concrete_func = root.f.get_concrete_function(input_signature)
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
tflite_model = converter.convert()
return tflite_model


def check_local():
tflite_fname = "model.tflite"
tflite_model = create_tflite_model()
temp = util.tempdir()
tflite_model_path = temp.relpath(tflite_fname)
open(tflite_model_path, 'wb').write(tflite_model)

# inference via tflite interpreter python apis
interpreter = tflite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

input_shape = input_details[0]['shape']
tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], tflite_input)
interpreter.invoke()
tflite_output = interpreter.get_tensor(output_details[0]['index'])

# inference via tvm tflite runtime
with open(tflite_model_path, 'rb') as model_fin:
runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
runtime.set_input(0, tvm.nd.array(tflite_input))
runtime.invoke()
out = runtime.get_output(0)
np.testing.assert_equal(out.asnumpy(), tflite_output)


def check_remote():
tflite_fname = "model.tflite"
tflite_model = create_tflite_model()
temp = util.tempdir()
tflite_model_path = temp.relpath(tflite_fname)
open(tflite_model_path, 'wb').write(tflite_model)

# inference via tflite interpreter python apis
interpreter = tflite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

input_shape = input_details[0]['shape']
tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], tflite_input)
interpreter.invoke()
tflite_output = interpreter.get_tensor(output_details[0]['index'])

# inference via remote tvm tflite runtime
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
ctx = remote.cpu(0)
a = remote.upload(tflite_model_path)

with open(tflite_model_path, 'rb') as model_fin:
runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
runtime.invoke()
out = runtime.get_output(0)
np.testing.assert_equal(out.asnumpy(), tflite_output)

check_local()
check_remote()


def _create_tflite_model():
if not tvm.runtime.enabled("tflite"):
print("skip because tflite runtime is not enabled...")
return
if not tvm.get_global_func("tvm.tflite_runtime.create", True):
print("skip because tflite runtime is not enabled...")
return

try:
import tensorflow as tf
except ImportError:
print('skip because tensorflow not installed...')
return

root = tf.Module()
root.const = tf.constant([1., 2.], tf.float32)
root.f = tf.function(lambda x: root.const * x)

input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32)
concrete_func = root.f.get_concrete_function(input_signature)
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
tflite_model = converter.convert()
return tflite_model


@pytest.mark.skip('skip because accessing output tensor is flakey')
def test_local():
if not tvm.runtime.enabled("tflite"):
print("skip because tflite runtime is not enabled...")
return
if not tvm.get_global_func("tvm.tflite_runtime.create", True):
print("skip because tflite runtime is not enabled...")
return

try:
import tensorflow as tf
except ImportError:
print('skip because tensorflow not installed...')
return

tflite_fname = "model.tflite"
tflite_model = _create_tflite_model()
temp = util.tempdir()
tflite_model_path = temp.relpath(tflite_fname)
open(tflite_model_path, 'wb').write(tflite_model)

# inference via tflite interpreter python apis
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

input_shape = input_details[0]['shape']
tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], tflite_input)
interpreter.invoke()
tflite_output = interpreter.get_tensor(output_details[0]['index'])

# inference via tvm tflite runtime
with open(tflite_model_path, 'rb') as model_fin:
runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
runtime.set_input(0, tvm.nd.array(tflite_input))
runtime.invoke()
out = runtime.get_output(0)
np.testing.assert_equal(out.asnumpy(), tflite_output)


def test_remote():
if not tvm.runtime.enabled("tflite"):
print("skip because tflite runtime is not enabled...")
return
if not tvm.get_global_func("tvm.tflite_runtime.create", True):
print("skip because tflite runtime is not enabled...")
return

try:
import tensorflow as tf
except ImportError:
print('skip because tensorflow not installed...')
return

tflite_fname = "model.tflite"
tflite_model = _create_tflite_model()
temp = util.tempdir()
tflite_model_path = temp.relpath(tflite_fname)
open(tflite_model_path, 'wb').write(tflite_model)

# inference via tflite interpreter python apis
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

input_shape = input_details[0]['shape']
tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], tflite_input)
interpreter.invoke()
tflite_output = interpreter.get_tensor(output_details[0]['index'])

# inference via remote tvm tflite runtime
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
ctx = remote.cpu(0)
a = remote.upload(tflite_model_path)

with open(tflite_model_path, 'rb') as model_fin:
runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0)))
runtime.invoke()
out = runtime.get_output(0)
np.testing.assert_equal(out.asnumpy(), tflite_output)

server.terminate()


if __name__ == "__main__":
# skipped_test_tflite_runtime()
pass
test_local()
test_remote()
3 changes: 3 additions & 0 deletions tests/scripts/task_config_build_cpu.sh
@@ -38,3 +38,6 @@ echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake
echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake
echo set\(USE_VTA_TSIM ON\) >> config.cmake
echo set\(USE_VTA_FSIM ON\) >> config.cmake
echo set\(USE_TFLITE ON\) >> config.cmake
echo set\(USE_TENSORFLOW_PATH \"/tensorflow\"\) >> config.cmake
echo set\(USE_FLATBUFFERS_PATH \"/flatbuffers\"\) >> config.cmake
