From 5ea5e93910ac4a0b78f70459c905a59d45ba0909 Mon Sep 17 00:00:00 2001 From: Rishabh Jain <56974688+jainris@users.noreply.github.com> Date: Tue, 11 Aug 2020 13:35:55 +0530 Subject: [PATCH] [TFLite] Implemented EXPAND_DIMS Operator for TFLite. (#6243) --- python/tvm/relay/frontend/tflite.py | 26 +++++++++ tests/python/frontend/tflite/test_forward.py | 56 ++++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index f168f1b65a947..11d657610c879 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -84,6 +84,7 @@ def __init__(self, model, subgraph, exp_tab): 'ELU': self.convert_elu, 'EQUAL': self.convert_equal, 'EXP': self.convert_exp, + 'EXPAND_DIMS': self.convert_expand_dims, 'FILL': self.convert_fill, 'FLOOR_DIV': self.convert_floor_div, 'FLOOR_MOD': self.convert_floor_mod, @@ -2904,6 +2905,31 @@ def convert_detection_postprocess(self, op): ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores, valid_count]), size=4) return ret + def convert_expand_dims(self, op): + """Convert TFLite EXPAND_DIMS""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + if input_tensors[0].qnn_params: + # Check that input and output tensor have same qnn params. + output_tensors = self.get_output_tensors(op) + assert self.has_same_qnn_params(input_tensors[0], output_tensors[0]), \ + "TFLite EXPAND_DIMS requires input and output tensors' \ + scale and zero points to be equal" + + input_expr = self.get_tensor_expr(input_tensors[0]) + axis = self.get_tensor_value(input_tensors[1]) + if isinstance(axis, np.ndarray): + assert len(axis) == 1, "only one value is expected." + axis = int(axis) + + ndims = len(input_tensors[0].tensor.ShapeAsNumpy()) + assert (-1-ndims <= axis <= ndims), "axis out of range" + + out = _op.expand_dims(input_expr, axis, 1) + + return out + def convert_one_hot(self, op): """Convert TFLite ONE_HOT""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 2e57175a969af..33ac6d4a15de2 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -2030,6 +2030,61 @@ def test_forward_padv2(): np.uint8(10)], quantized=True) +####################################################################### +# EXPAND_DIMS +# ----------- + +def _test_expand_dims(input_shape, input_type, axis, quantized=False): + """ One iteration of EXPAND_DIMS """ + with tf.Graph().as_default(): + axis= ops.convert_to_tensor(axis, dtype=axis.dtype) + + if quantized: + # ignoring input_type as quantized requires uint8 + input = np.random.uniform(0, 256, input_shape).astype('uint8') + in_input = tf.placeholder(dtype='float32', shape=input.shape, name="input") + + input_range = {'q_input': (-100, 100)} + inq_input = tf.quantization.fake_quant_with_min_max_args( + in_input, + min=-100, + max=100, + name="q_input") + + out = array_ops.expand_dims(inq_input, axis=axis) + out = tf.quantization.fake_quant_with_min_max_args( + out, + min=-100, + max=100, + name="out") + + compare_tflite_with_tvm( + [input], + ["q_input"], + [inq_input], + [out], + quantized=True, + input_range=input_range) + else: + input = np.random.uniform(-100, 100, input_shape).astype(input_type) + in_input = tf.placeholder(dtype=input.dtype, shape=input.shape, name="input") + + out = array_ops.expand_dims(in_input, axis=axis) + + compare_tflite_with_tvm( + [input], + ["input"], + [in_input], + [out]) + +def test_forward_expand_dims(): + """ EXPAND_DIMS """ + for quantized in [False, True]: + _test_expand_dims((6, 2, 7, 5), 'float32', np.int32(0), quantized=quantized) + _test_expand_dims((1, 2, 3), 'int32', np.int32(-2), quantized=quantized) + _test_expand_dims((2, 4, 5), 'float32', np.array([1], dtype=np.int32), quantized=quantized) + + ####################################################################### # ONE_HOT # ------- @@ -3021,6 +3076,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_select() test_forward_quantize_dequantize() test_forward_arg_min_max() + test_forward_expand_dims() # NN test_forward_convolution()