Re-enable tflite runtime unit tests with guard
michalpiszczek committed May 14, 2020
1 parent 82f6bbb commit 75543d8
Showing 4 changed files with 115 additions and 82 deletions.
3 changes: 3 additions & 0 deletions src/runtime/contrib/tflite/tflite_runtime.cc
@@ -177,5 +177,8 @@ Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, TVMContext ctx
TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = TFLiteRuntimeCreate(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("target.runtime.tflite")
    .set_body_typed(TFLiteRuntimeCreate);
} // namespace runtime
} // namespace tvm
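
A quick way to see the effect of these registrations from the Python side — a minimal sketch, assuming a TVM build with USE_TFLITE=ON (tvm.get_global_func with allow_missing=True returns None rather than raising when a global is absent):

import tvm

# Both globals are registered by tflite_runtime.cc, so they resolve only
# when the TFLite runtime was compiled into this TVM build.
fcreate = tvm.get_global_func("tvm.tflite_runtime.create", allow_missing=True)
fguard = tvm.get_global_func("target.runtime.tflite", allow_missing=True)
print("tvm.tflite_runtime.create registered:", fcreate is not None)
print("target.runtime.tflite registered:", fguard is not None)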
2 changes: 2 additions & 0 deletions src/runtime/module.cc
@@ -129,6 +129,8 @@ bool RuntimeEnabled(const std::string& target) {
f_name = "device_api.opencl";
} else if (target == "mtl" || target == "metal") {
f_name = "device_api.metal";
} else if (target == "tflite") {
f_name = "target.runtime.tflite";
} else if (target == "vulkan") {
f_name = "device_api.vulkan";
} else if (target == "stackvm") {
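
With this mapping in place, RuntimeEnabled("tflite") checks the global function registry for the "target.runtime.tflite" global registered above, instead of unconditionally reporting the runtime as unavailable. A minimal sketch of the resulting Python-side guard, mirroring what the re-enabled tests below do:

import tvm

# tvm.runtime.enabled("tflite") routes through RuntimeEnabled in module.cc;
# it returns True only when "target.runtime.tflite" is in the registry,
# i.e. when TVM was built with the TFLite runtime compiled in.
if not tvm.runtime.enabled("tflite"):
    print("skip because tflite runtime is not enabled...")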
189 changes: 107 additions & 82 deletions tests/python/contrib/test_tflite_runtime.py
@@ -14,92 +14,117 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

import tvm
from tvm import te
import numpy as np
from tvm import rpc
from tvm.contrib import util, tflite_runtime
# import tensorflow as tf
# import tflite_runtime.interpreter as tflite


def skipped_test_tflite_runtime():

    def create_tflite_model():
        root = tf.Module()
        root.const = tf.constant([1., 2.], tf.float32)
        root.f = tf.function(lambda x: root.const * x)

        input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32)
        concrete_func = root.f.get_concrete_function(input_signature)
        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        tflite_model = converter.convert()
        return tflite_model

    def check_local():
        tflite_fname = "model.tflite"
        tflite_model = create_tflite_model()
        temp = util.tempdir()
        tflite_model_path = temp.relpath(tflite_fname)
        open(tflite_model_path, 'wb').write(tflite_model)

        # inference via tflite interpreter python apis
        interpreter = tflite.Interpreter(model_path=tflite_model_path)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        input_shape = input_details[0]['shape']
        tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
        interpreter.set_tensor(input_details[0]['index'], tflite_input)
        interpreter.invoke()
        tflite_output = interpreter.get_tensor(output_details[0]['index'])

        # inference via tvm tflite runtime
        with open(tflite_model_path, 'rb') as model_fin:
            runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
            runtime.set_input(0, tvm.nd.array(tflite_input))
            runtime.invoke()
            out = runtime.get_output(0)
            np.testing.assert_equal(out.asnumpy(), tflite_output)

    def check_remote():
        tflite_fname = "model.tflite"
        tflite_model = create_tflite_model()
        temp = util.tempdir()
        tflite_model_path = temp.relpath(tflite_fname)
        open(tflite_model_path, 'wb').write(tflite_model)

        # inference via tflite interpreter python apis
        interpreter = tflite.Interpreter(model_path=tflite_model_path)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        input_shape = input_details[0]['shape']
        tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
        interpreter.set_tensor(input_details[0]['index'], tflite_input)
        interpreter.invoke()
        tflite_output = interpreter.get_tensor(output_details[0]['index'])

        # inference via remote tvm tflite runtime
        server = rpc.Server("localhost")
        remote = rpc.connect(server.host, server.port)
        ctx = remote.cpu(0)
        remote.upload(tflite_model_path)

        with open(tflite_model_path, 'rb') as model_fin:
            runtime = tflite_runtime.create(model_fin.read(), ctx)
            runtime.set_input(0, tvm.nd.array(tflite_input, ctx))
            runtime.invoke()
            out = runtime.get_output(0)
            np.testing.assert_equal(out.asnumpy(), tflite_output)

    check_local()
    check_remote()


def _create_tflite_model():
    # tensorflow is imported here rather than at module scope so this file can
    # be collected even when tensorflow is not installed; callers check for it first.
    import tensorflow as tf

    root = tf.Module()
    root.const = tf.constant([1., 2.], tf.float32)
    root.f = tf.function(lambda x: root.const * x)

    input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32)
    concrete_func = root.f.get_concrete_function(input_signature)
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    tflite_model = converter.convert()
    return tflite_model


@pytest.mark.skip('skip because accessing output tensor is flaky')
def test_local():
    if not tvm.runtime.enabled("tflite"):
        print("skip because tflite runtime is not enabled...")
        return
    if not tvm.get_global_func("tvm.tflite_runtime.create", True):
        print("skip because tflite runtime is not enabled...")
        return

    try:
        import tensorflow as tf
    except ImportError:
        print('skip because tensorflow not installed...')
        return

    tflite_fname = "model.tflite"
    tflite_model = _create_tflite_model()
    temp = util.tempdir()
    tflite_model_path = temp.relpath(tflite_fname)
    with open(tflite_model_path, 'wb') as model_fout:
        model_fout.write(tflite_model)

    # inference via tflite interpreter python apis
    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_shape = input_details[0]['shape']
    tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], tflite_input)
    interpreter.invoke()
    tflite_output = interpreter.get_tensor(output_details[0]['index'])

    # inference via tvm tflite runtime
    with open(tflite_model_path, 'rb') as model_fin:
        runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0))
        runtime.set_input(0, tvm.nd.array(tflite_input))
        runtime.invoke()
        out = runtime.get_output(0)
        np.testing.assert_equal(out.asnumpy(), tflite_output)


def test_remote():
    if not tvm.runtime.enabled("tflite"):
        print("skip because tflite runtime is not enabled...")
        return
    if not tvm.get_global_func("tvm.tflite_runtime.create", True):
        print("skip because tflite runtime is not enabled...")
        return

    try:
        import tensorflow as tf
    except ImportError:
        print('skip because tensorflow not installed...')
        return

    tflite_fname = "model.tflite"
    tflite_model = _create_tflite_model()
    temp = util.tempdir()
    tflite_model_path = temp.relpath(tflite_fname)
    with open(tflite_model_path, 'wb') as model_fout:
        model_fout.write(tflite_model)

    # inference via tflite interpreter python apis
    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_shape = input_details[0]['shape']
    tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], tflite_input)
    interpreter.invoke()
    tflite_output = interpreter.get_tensor(output_details[0]['index'])

    # inference via remote tvm tflite runtime
    server = rpc.Server("localhost")
    remote = rpc.connect(server.host, server.port)
    ctx = remote.cpu(0)
    remote.upload(tflite_model_path)

    with open(tflite_model_path, 'rb') as model_fin:
        runtime = tflite_runtime.create(model_fin.read(), ctx)
        runtime.set_input(0, tvm.nd.array(tflite_input, ctx))
        runtime.invoke()
        out = runtime.get_output(0)
        np.testing.assert_equal(out.asnumpy(), tflite_output)

    server.terminate()

if __name__ == "__main__":
    # skipped_test_tflite_runtime()
    test_local()
    test_remote()
3 changes: 3 additions & 0 deletions tests/scripts/task_config_build_cpu.sh
@@ -38,3 +38,6 @@ echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake
echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake
echo set\(USE_VTA_TSIM ON\) >> config.cmake
echo set\(USE_VTA_FSIM ON\) >> config.cmake
echo set\(USE_TFLITE ON\) >> config.cmake
echo set\(USE_TENSORFLOW_PATH \"/tensorflow\"\) >> config.cmake
echo set\(USE_FLATBUFFERS_PATH \"/flatbuffers\"\) >> config.cmake
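
For reference, the three appended echo commands expand to the following lines at the end of config.cmake; the /tensorflow and /flatbuffers paths are presumably where the CI image provides those dependencies, so a local build would substitute its own checkout paths:

set(USE_TFLITE ON)
set(USE_TENSORFLOW_PATH "/tensorflow")
set(USE_FLATBUFFERS_PATH "/flatbuffers")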
