[Frontend][TFLite] Add parser support for shape and range (apache#5329)
* [Relay][Frontend][TFLite] Add parser support for shape and range

Signed-off-by: Dhruva Ray <[email protected]>

* Incorporated review comments and used new functions

Signed-off-by: Dhruva Ray <[email protected]>

* Few cosmetic changes

Signed-off-by: Dhruva Ray <[email protected]>

* Removed an extra line added by rebase...

Signed-off-by: Dhruva Ray <[email protected]>
dhruvaray authored and trevor-m committed Jun 18, 2020
1 parent b6b241f commit cba2e37
Showing 2 changed files with 179 additions and 24 deletions.
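
For context, here is a rough end-to-end sketch (illustrative only, not part of the commit) of how the newly supported operators can be exercised: build a small TF graph using tf.range and tf.shape, convert it to TFLite, and import it through relay.frontend.from_tflite. It assumes TF 1.14+, the tflite python package, and a TVM build from around the time of this commit; because RANGE has a data-dependent output shape, the result is executed on the Relay VM rather than the graph runtime.

import numpy as np
import tensorflow as tf
import tvm
from tvm import relay

# Build a tiny TF 1.x graph: shape(range(start, limit)) with scalar inputs.
with tf.Graph().as_default():
    start = tf.placeholder(dtype=tf.int32, shape=(), name="start")
    limit = tf.placeholder(dtype=tf.int32, shape=(), name="limit")
    out = tf.shape(tf.range(start, limit), out_type=tf.int32)
    with tf.Session() as sess:
        converter = tf.lite.TFLiteConverter.from_session(sess, [start, limit], [out])
        tflite_buf = converter.convert()

# Parse the flatbuffer; the import path changed between tflite 1.x and 2.1.
try:
    import tflite
    model = tflite.Model.GetRootAsModel(tflite_buf, 0)
except AttributeError:
    import tflite.Model
    model = tflite.Model.Model.GetRootAsModel(tflite_buf, 0)

mod, params = relay.frontend.from_tflite(
    model,
    shape_dict={"start": (), "limit": ()},
    dtype_dict={"start": "int32", "limit": "int32"})

# RANGE produces a data-dependent output shape, so use the Relay VM.
ex = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(), target="llvm")
print(ex.evaluate()(tvm.nd.array(np.int32(1)), tvm.nd.array(np.int32(18))))  # 1-element tensor: [17]
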
35 changes: 35 additions & 0 deletions python/tvm/relay/frontend/tflite.py
@@ -114,6 +114,7 @@ def __init__(self, model, subgraph, exp_tab):
'PAD': self.convert_pad,
'POW': self.convert_pow,
'PRELU': self.convert_prelu,
'RANGE': self.convert_range,
'QUANTIZE': self.convert_quantize,
'REDUCE_ANY': self.convert_reduce_any,
'REDUCE_MAX': self.convert_reduce_max,
@@ -126,6 +127,7 @@ def __init__(self, model, subgraph, exp_tab):
'ROUND': self.convert_round,
'RSQRT': self.convert_rsqrt,
'SELECT': self.convert_select,
'SHAPE': self.convert_shape,
'SIN': self.convert_sin,
'SLICE': self.convert_slice,
'SOFTMAX': self.convert_softmax,
@@ -609,6 +611,39 @@ def convert_tanh(self, op):

return out

def convert_range(self, op):
"""Convert TFLite Range"""
try:
from tflite.TensorType import TensorType
except ImportError:
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 3, "input tensors length should be 3"

start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2]

expressions = [self.get_tensor_expr(t) for t in [start, limit, delta]]

# out type inference
if delta.tensor.Type() == TensorType.FLOAT32:
out_type = self.get_tensor_type_str(delta.tensor.Type())
else:
out_type = self.get_tensor_type_str(start.tensor.Type())

out = _op.arange(expressions[0], expressions[1], expressions[2], out_type)

return out

def convert_shape(self, op):
"""Convert TFLite Shape"""
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

out = _op.shape_of(self.get_tensor_expr(input_tensors[0]))

return out

def convert_relu(self, op):
"""Convert TFLite ReLU"""
input_tensors = self.get_input_tensors(op)
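
For reference, a rough sketch (illustrative only, not part of the diff) of the Relay expressions the two converters above build: RANGE maps to relay.arange with the same output-dtype rule as convert_range, and SHAPE maps to relay.shape_of.

from tvm import relay

# RANGE: scalar start/limit/delta become relay.arange.  As in convert_range,
# a float32 delta forces a float32 output dtype; otherwise the start dtype is used.
start = relay.const(1, "int32")
limit = relay.const(18, "int32")
delta = relay.const(3, "int32")
out_dtype = "float32" if delta.data.dtype == "float32" else start.data.dtype
rng = relay.arange(start, limit, delta, dtype=out_dtype)

# SHAPE: the single input maps to relay.shape_of, which yields a 1-D tensor
# holding the runtime shape of its argument.
data = relay.var("data", shape=(1, 18, 3), dtype="int32")
shp = relay.shape_of(data)

print(relay.Function([], rng))
print(relay.Function([data], shp))
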
168 changes: 144 additions & 24 deletions tests/python/frontend/tflite/test_forward.py
@@ -83,8 +83,34 @@ def get_real_image_object_detection(im_height, im_width):
data = np.reshape(x, (1, im_height, im_width, 3))
return data

def vmobj_to_list(o):
if isinstance(o, tvm.nd.NDArray):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.runtime.container.ADT):
result = []
for f in o:
result.extend(vmobj_to_list(f))
return result
elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
if o.constructor.name_hint == 'Cons':
tl = vmobj_to_list(o.fields[1])
hd = vmobj_to_list(o.fields[0])
hd.extend(tl)
return hd
elif o.constructor.name_hint == 'Nil':
return []
elif 'tensor_nil' in o.constructor.name_hint:
return [0]
elif 'tensor' in o.constructor.name_hint:
return [o.fields[0].asnumpy()]
else:
raise RuntimeError("Unknown object type: %s" %
o.constructor.name_hint)
else:
raise RuntimeError("Unknown object type: %s" % type(o))
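
# Illustrative only (not part of this diff): vmobj_to_list flattens whatever the
# VM or debug executor returns -- a bare NDArray or an ADT (tuple) of NDArrays --
# into one flat python list.  Assumes the numpy/tvm imports at the top of this file.
_single = tvm.nd.array(np.array([1, 2, 3], dtype="int32"))
_tup = tvm.runtime.container.tuple_object(
    [_single, tvm.nd.array(np.array([7.0], dtype="float32"))])
assert vmobj_to_list(_single) == [[1, 2, 3]]
assert vmobj_to_list(_tup) == [[1, 2, 3], [7.0]]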

def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
out_names=None):
out_names=None, mode='graph_runtime'):
""" Generic function to compile on relay and execute on tvm """
# TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
try:
@@ -109,27 +135,43 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target
shape_dict=shape_dict,
dtype_dict=dtype_dict)

with tvm.transform.PassContext(opt_level=3):
graph, lib, params = relay.build(mod, target, params=params)

ctx = tvm.context(target, 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
for i, e in enumerate(input_node):
m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))

m.set_input(**params)
# execute
m.run()
# get outputs
assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format(
out_names, num_output)
tvm_output_list = []
for i in range(0, num_output):
tvm_output = m.get_output(i)
tvm_output_list.append(tvm_output.asnumpy())
return tvm_output_list
if mode in ['debug', 'vm']:
ex = relay.create_executor(mode, mod=mod, ctx=tvm.cpu(), target="llvm")
inputs = []
for param in mod['main'].params:
found = False
for i, n in enumerate(input_node):
if n == param.name_hint:
found = True
inputs.append(tvm.nd.array(input_data[i]))
break
# Interpreter doesn't bind constants, so still need to find in params
if not found:
inputs.append(tvm.nd.array(params[param.name_hint]))
result = ex.evaluate()(*inputs)
return vmobj_to_list(result)
else:
with tvm.transform.PassContext(opt_level=3):
graph, lib, params = relay.build(mod, target, params=params)

ctx = tvm.context(target, 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
# set inputs
for i, e in enumerate(input_node):
m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype)))

m.set_input(**params)
# execute
m.run()
# get outputs
assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format(
out_names, num_output)
tvm_output_list = []
for i in range(0, num_output):
tvm_output = m.get_output(i)
tvm_output_list.append(tvm_output.asnumpy())
return tvm_output_list


def run_tflite_graph(tflite_model_buf, input_data):
@@ -160,7 +202,7 @@ def run_tflite_graph(tflite_model_buf, input_data):

def compare_tflite_with_tvm(in_data, in_name, input_tensors,
output_tensors, init_global_variables=False,
out_names=None, quantized=False, input_range=None):
out_names=None, quantized=False, input_range=None, mode='graph_runtime'):
"""Generic function to generate and compare TFLite and TVM output"""
in_data = convert_to_list(in_data)
in_name = convert_to_list(in_name)
@@ -202,7 +244,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
continue

tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device,
num_output=len(out_names), out_names=out_names)
num_output=len(out_names), out_names=out_names, mode=mode)

# WARNING: the results could well be random values clipped to 0 or 255 because of badly tuned output
# range for the specific operator. While adding test ensure that we aren't getting only clipped values
@@ -859,6 +901,80 @@ def test_all_resize():
if 'RESIZE_NEAREST_NEIGHBOR' in dir(BuiltinOperator()):
_test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False)

#######################################################################
# Range
# -----
def _test_range(start, limit, delta):
# tflite 1.13 convert method does not accept empty shapes
if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
tf.reset_default_graph()
with tf.Graph().as_default():
start_scalar, limit_scalar, delta_scalar = \
tf.placeholder(dtype=start.dtype, shape=(), name="start"), \
tf.placeholder(dtype=limit.dtype, shape=(), name="limit"), \
tf.placeholder(dtype=delta.dtype, shape=(), name="delta")

out = tf.range(start_scalar, limit_scalar, delta_scalar, name="range")

compare_tflite_with_tvm(
[start, limit, delta],
["start", "limit", "delta"],
[start_scalar, limit_scalar, delta_scalar],
[out],
mode="vm",
quantized=False
)

def _test_range_default():
# tflite 1.13 convert method does not accept empty shapes
if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
tf.reset_default_graph()
with tf.Graph().as_default():
inputs = [
tf.placeholder(dtype=tf.int32, shape=(), name="p1"),
tf.placeholder(dtype=tf.int32, shape=(), name="p2")
]
outputs = [
tf.range(start = inputs[0], limit = inputs[1]), # use default delta
tf.range(start = inputs[1]) # use start as limit with 0 as the first item in the range
]

compare_tflite_with_tvm(
[np.int32(1), np.int32(18)],
["p1", "p2"],
inputs,
outputs,
mode="vm"
)

def test_forward_range():
_test_range(np.int32(1), np.int32(18), np.int32(3))
_test_range(np.int32(1), np.int32(18), np.float32(3.1)) # increment is of type float
_test_range(np.float32(1.0), np.int32(18), np.int32(3.1)) # start is of type float
_test_range_default()

#######################################################################
# Shape
# -----
def test_forward_shape():
# tflite 1.13 convert method does not accept empty shapes
if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
tf.reset_default_graph()
with tf.Graph().as_default():
data = np.array([1, 18, 3], dtype=np.int32)
start = tf.placeholder(dtype=tf.int32, shape=[], name="start")
limit = tf.placeholder(dtype=tf.int32, shape=[], name="limit")
delta = tf.placeholder(dtype=tf.int32, shape=[], name="delta")
r = tf.range(start, limit, delta, tf.int32, name="range")
out = tf.shape(r, out_type=tf.dtypes.int32)
compare_tflite_with_tvm(
[x for x in np.nditer(data)],
["start", "limit", "delta"],
[start, limit, delta],
[out],
mode="vm"
)

#######################################################################
# Concatenation
# -------------
@@ -2363,13 +2479,17 @@ def test_forward_mediapipe_hand_landmark():
# Tile
test_forward_tile()

# Query
test_forward_shape()

# Transforms
test_forward_concatenation()
test_forward_pad()
test_forward_pack()
test_forward_unpack()
test_forward_reshape()
test_all_resize()
test_forward_range()
test_forward_squeeze()
test_forward_slice()
test_forward_topk()
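
The new coverage can also be run on its own; a sketch (assumes pytest and the TF/tflite versions this suite expects, run from the TVM source root):

import pytest

# Run only the new TFLite RANGE/SHAPE tests.
pytest.main([
    "tests/python/frontend/tflite/test_forward.py",
    "-k", "forward_range or forward_shape",
    "-v",
])
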
