From 86619b93f55d3ceab90a6b94f5bfdaed14bcdbdf Mon Sep 17 00:00:00 2001 From: Dhruva Ray Date: Fri, 8 May 2020 16:24:35 +0530 Subject: [PATCH] Incorporated review comments and used new functions Signed-off-by: Dhruva Ray --- python/tvm/relay/frontend/tflite.py | 30 ++------------------ tests/python/frontend/tflite/test_forward.py | 3 +- 2 files changed, 5 insertions(+), 28 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index d5190438fc567..f33feb7acc262 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -608,57 +608,33 @@ def convert_tanh(self, op): def convert_range(self, op): """Convert TFLite Range""" try: - from tflite.Operator import Operator from tflite.TensorType import TensorType except ImportError: raise ImportError("The tflite package must be installed") - if self.is_quantized(op): - raise tvm.error.OpNotImplemented( - 'TFlite quantized RANGE operator is not supported yet.') - - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 3, "input tensors length should be 3" start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2] - expressions = [] - for t in [start, limit, delta]: - if self.has_expr(t.tensor_idx): - expressions.append(self.get_expr(t.tensor_idx)) - else: - tensor_type = self.get_tensor_type_str(t.tensor.Type()) - tensor_value = self.get_tensor_value(t) - expressions.append(self.exp_tab.new_const(tensor_value, dtype=tensor_type)) + expressions = [self.get_tensor_expr(t) for t in [start, limit, delta]] - #out type inference + # out type inference if delta.tensor.Type() == TensorType.FLOAT32: out_type = self.get_tensor_type_str(delta.tensor.Type()) else: out_type = self.get_tensor_type_str(start.tensor.Type()) - #put type here form op out = _op.arange(expressions[0], expressions[1], expressions[2], out_type) return out def convert_shape(self, op): """Convert TFLite Shape""" - try: - from tflite.Operator import Operator - except ImportError: - raise ImportError("The tflite package must be installed") - - if self.is_quantized(op): - raise tvm.error.OpNotImplemented( - 'TFlite quantized SHAPE operator is not supported yet.') - - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" - out = _op.shape_of(self.get_expr(input_tensors[0].tensor_idx)) + out = _op.shape_of(self.get_tensor_expr(input_tensors[0])) return out diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 351194b60dc2b..60123bb9748ad 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -244,7 +244,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors, continue tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device, - num_output=len(out_names), out_names=out_names,mode=mode) + num_output=len(out_names), out_names=out_names, mode=mode) # WARNING: the results could well be random values clipped to 0 or 255 because of badly tuned output # range for the specific operator. While adding test ensure that we aren't getting only clipped values @@ -910,6 +910,7 @@ def test_forward_shape(): mode="vm", quantized=False ) + ####################################################################### # Concatenation # -------------