Skip to content

Commit

Permalink
[TFLite] pack operation extedned with const args (apache#6984)
Browse files Browse the repository at this point in the history
pack operation now accepts constant arguments
  • Loading branch information
d-smirnov authored and masahi committed Dec 24, 2020
1 parent 293a341 commit a9e04a8
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
8 changes: 4 additions & 4 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2524,9 +2524,6 @@ def convert_pack(self, op):
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) >= 1, "input tensors should greater than 1"
in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors]

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"

Expand All @@ -2535,8 +2532,11 @@ def convert_pack(self, op):
pack_options = PackOptions()
pack_options.Init(op_options.Bytes, op_options.Pos)
pack_axis = pack_options.Axis()
pack_values_count = pack_options.ValuesCount()
assert len(input_tensors) == pack_values_count, "Discordance in input values count"

in_exprs_reshaped = [_op.expand_dims(i, axis=pack_axis, num_newaxis=1) for i in in_exprs]
in_exprs = [self.get_tensor_expr(_) for _ in input_tensors]
in_exprs_reshaped = [_op.expand_dims(_, axis=pack_axis, num_newaxis=1) for _ in in_exprs]
out = _op.concatenate(in_exprs_reshaped, pack_axis)
return out

Expand Down
26 changes: 18 additions & 8 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2750,34 +2750,44 @@ def test_forward_one_hot():
# ----


def _test_pack(data, axis):
def _test_pack(data, is_var, axis):
""" One iteration of pack """

assert len(data) >= 1
assert len(data) == len(is_var)

with tf.Graph().as_default():
in_data = [
array_ops.placeholder(shape=tensor.shape, dtype=tensor.dtype, name="in_{}".format(idx))
for idx, tensor in enumerate(data)
array_ops.placeholder(shape=d.shape, dtype=d.dtype, name="in_" + str(idx))
if is_var[idx]
else constant_op.constant(
d, shape=d.shape, dtype=d.dtype, name="in_constant_" + str(idx)
)
for idx, d in enumerate(data)
]
out = array_ops.pack(in_data, axis=axis)
name = ["in_{}:0".format(idx) for idx in range(len(data))]

compare_tflite_with_tvm(data, name, in_data, [out])
out = array_ops.pack(in_data, axis=axis)
name = [_.name for _ in in_data]
compare_tflite_with_tvm(data, name, in_data, [out], experimental_new_converter=True)


def test_forward_pack():
""" Pack """
_test_pack([np.arange(6).reshape((1, 2, 1, 3)), np.arange(6).reshape((1, 2, 1, 3))], 1)
_test_pack([np.int32(1), np.int32(5)], [False, False], 0)
_test_pack([np.array([1, 4]), np.array([2, 5]), np.array([3, 6])], [True, False, False], 0)
_test_pack(
[np.arange(6).reshape((1, 2, 1, 3)), np.arange(6).reshape((1, 2, 1, 3))], [True, True], 1
)

_test_pack([np.arange(6).reshape((3, 2)), np.arange(6).reshape((3, 2))], 1)
_test_pack([np.arange(6).reshape((3, 2)), np.arange(6).reshape((3, 2))], [True, True], 1)

_test_pack(
[
np.arange(6).reshape((2, 1, 1, 3)),
np.arange(6).reshape((2, 1, 1, 3)),
np.arange(6).reshape((2, 1, 1, 3)),
],
[True, True, True],
1,
)

Expand Down

0 comments on commit a9e04a8

Please sign in to comment.