[TFLite] Add transpose_conv to TFLite parser (apache#4440)
apivovarov authored and Xingyu Zhou committed Dec 13, 2019
1 parent 9cd965a commit 6b92fd6
Showing 2 changed files with 135 additions and 1 deletion.
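
With this change, a float32 TFLite model that contains a TRANSPOSE_CONV operator can be imported through the usual frontend entry point. A minimal sketch, assuming a model file "deconv.tflite" whose input tensor is named "input" (both names are hypothetical):

import tflite.Model
from tvm import relay

buf = open("deconv.tflite", "rb").read()
tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)

# The shape_dict/dtype_dict keys must match the model's actual input name.
mod, params = relay.frontend.from_tflite(tflite_model,
                                         shape_dict={"input": (1, 32, 32, 16)},
                                         dtype_dict={"input": "float32"})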
81 changes: 80 additions & 1 deletion python/tvm/relay/frontend/tflite.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name, unused-argument
+# pylint: disable=invalid-name, unused-argument, too-many-lines
"""Tensorflow lite frontend."""
from __future__ import absolute_import as _abs
import math
@@ -96,6 +96,7 @@ def __init__(self, model, subgraph, exp_tab):
            'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
            'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
            'PRELU': self.convert_prelu,
            'TRANSPOSE_CONV': self.convert_transpose_conv,
        }

    def check_unsupported_ops(self):
@@ -1370,6 +1371,84 @@ def convert_prelu(self, op):

        return out

    def convert_transpose_conv(self, op):
        """Convert TFLite TRANSPOSE_CONV"""
        try:
            from tflite.BuiltinOptions import BuiltinOptions
            from tflite.TensorType import TensorType
            from tflite.Operator import Operator
            from tflite.TransposeConvOptions import TransposeConvOptions
            from tflite.Padding import Padding
        except ImportError:
            raise ImportError("The tflite package must be installed")

        assert isinstance(op, Operator)
        input_tensors = self.get_input_tensors(op)
        assert len(input_tensors) == 3, "input tensors length should be 3"

        # Input (data) Tensor. NHWC layout
        input_tensor = input_tensors[2]
        _, _, _, input_c = input_tensor.tensor.ShapeAsNumpy()
        # Weights tensor. TFLite uses OHWI layout
        weights_tensor = input_tensors[1]
        out_channels, kernel_h, kernel_w, in_channels = weights_tensor.tensor.ShapeAsNumpy()
        assert input_c == in_channels, \
            "Input channels in the filter should match the channels in the input"
        # output_shape Tensor. NHWC layout
        output_shape_tensor = input_tensors[0]

        output_tensors = self.get_output_tensors(op)
        assert len(output_tensors) == 1, "output tensors length should be 1"
        output_tensor = output_tensors[0]
        output_tensor_type = output_tensor.tensor.Type()
        output_tensor_type_str = self.get_tensor_type_str(output_tensor_type)

        assert op.BuiltinOptionsType() == BuiltinOptions.TransposeConvOptions
        op_options = op.BuiltinOptions()
        deconv_options = TransposeConvOptions()
        deconv_options.Init(op_options.Bytes, op_options.Pos)

        padding = deconv_options.Padding()
        stride_h = deconv_options.StrideH()
        stride_w = deconv_options.StrideW()
        assert padding in (Padding.VALID, Padding.SAME), \
            'Padding format {} is not supported for operator TRANSPOSE_CONV'.format(padding)

        # Data
        in_expr = self.get_expr(input_tensor.tensor_idx)

        # Weights
        weights_tensor_type = weights_tensor.tensor.Type()
        # weights tensor type should be UINT8 (quantization) or FLOAT32
        assert weights_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
        weight_tensor_type_str = self.get_tensor_type_str(weights_tensor_type)
        weight_value_ohwi = self.get_tensor_value(weights_tensor)
        # Relay declares kernel_layout as OIHW, but conv2d_transpose actually
        # consumes the weights in IOHW order, so transpose OHWI -> IOHW.
        weight_value_iohw = np.transpose(weight_value_ohwi, (3, 0, 1, 2))
        weight_expr_iohw = self.exp_tab.new_const(weight_value_iohw, dtype=weight_tensor_type_str)

        # Output shape value
        output_shape_value = self.get_tensor_value(output_shape_tensor)
        # Relay expects the filter's output channels to match the output tensor's channels.
        assert out_channels == output_shape_value[3], \
            "Output channels in the filter should match the channels in output_shape"

        # The TF frontend supports 'SAME' padding for 1x1 kernels only; do the same here.
        if padding == Padding.SAME:
            assert (kernel_h, kernel_w) == (1, 1), \
                "SAME padding is supported for kernel (1,1) only"

        out = _op.nn.conv2d_transpose(in_expr, weight_expr_iohw,
                                      strides=(stride_h, stride_w),
                                      channels=int(out_channels),
                                      kernel_size=(int(kernel_h), int(kernel_w)),
                                      data_layout="NHWC",
                                      kernel_layout="OIHW",
                                      out_dtype=output_tensor_type_str)

        return out

    def get_expr(self, input_tensor_idx):
        return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))

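The OHWI-to-IOHW weight shuffle is the subtlest step in the converter above. A small numpy sketch (shapes are illustrative only) confirming that axes (3, 0, 1, 2) produce the IOHW order that conv2d_transpose consumes:

import numpy as np

# Illustrative kernel: O=5 output channels, 3x3 spatial, I=16 input channels.
w_ohwi = np.arange(5 * 3 * 3 * 16, dtype='float32').reshape(5, 3, 3, 16)
w_iohw = np.transpose(w_ohwi, (3, 0, 1, 2))
assert w_iohw.shape == (16, 5, 3, 3)  # (I, O, H, W)
# OHWI[o, h, w, i] lands at IOHW[i, o, h, w]
assert w_ohwi[2, 1, 0, 7] == w_iohw[7, 2, 1, 0]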
55 changes: 55 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
@@ -477,6 +477,60 @@ def test_forward_convolution():
    _test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC', True)


#######################################################################
# Transpose Convolution
# ---------------------

def _test_transpose_conv(tensor_in_sizes, filter_in_sizes, output_shape, strides, padding):
    """ One iteration of transpose convolution with given shapes and attributes """

    total_size_1 = 1
    total_size_2 = 1
    for s in tensor_in_sizes:
        total_size_1 *= s
    for s in filter_in_sizes:
        total_size_2 *= s
    # Initialize the input tensor with an array of incrementing
    # numbers starting from 1.
    data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
    filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]

    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
        in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
        strides = [1] + strides + [1]
        # in_filter layout is HWOI
        out = nn_ops.conv2d_transpose(in_data,
                                      in_filter,
                                      output_shape=output_shape,
                                      strides=strides,
                                      padding=padding)
        data_array = np.reshape(data_array, tensor_in_sizes).astype('float32')
        compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out])


def test_forward_transpose_conv():
    # kernel 3x3, padding VALID
    _test_transpose_conv([4, 32, 32, 16], [3, 3, 5, 16], [4, 34, 34, 5], [1, 1], 'VALID')
    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 65, 5], [2, 2], 'VALID')
    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 34, 5], [2, 1], 'VALID')

    # kernel 2x2, padding VALID
    _test_transpose_conv([4, 32, 32, 16], [2, 2, 5, 16], [4, 33, 33, 5], [1, 1], 'VALID')
    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], 'VALID')
    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 33, 5], [2, 1], 'VALID')

    # kernel 1x1, padding VALID
    _test_transpose_conv([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], 'VALID')
    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], 'VALID')
    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], 'VALID')

    # kernel 1x1, padding SAME
    _test_transpose_conv([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], 'SAME')
    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], 'SAME')
    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], 'SAME')
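
For the 'VALID' cases above, the expected output sizes follow the standard transposed-convolution relation out = (in - 1) * stride + kernel; a quick sanity check:

def deconv_out_size(in_size, kernel, stride):
    # 'VALID' padding: out = (in - 1) * stride + kernel
    return (in_size - 1) * stride + kernel

assert deconv_out_size(32, 3, 1) == 34  # the [4, 34, 34, 5] case
assert deconv_out_size(32, 3, 2) == 65  # the [1, 65, 65, 5] case
assert deconv_out_size(32, 1, 2) == 63  # the [1, 63, 63, 5] case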


#######################################################################
# Reshape
# -------
@@ -1232,6 +1286,7 @@ def test_forward_mediapipe_hand_landmark():

    # NN
    test_forward_convolution()
    test_forward_transpose_conv()
    test_forward_logistic()
    test_forward_pooling()
    test_forward_softmax()
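To run just the new coverage locally, one option (assuming a standard pytest setup) is:

python -m pytest tests/python/frontend/tflite/test_forward.py -k transpose_conv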
