Skip to content

Commit

Permalink
Incorporated review comments and used new functions
Browse files Browse the repository at this point in the history
Signed-off-by: Dhruva Ray <[email protected]>
  • Loading branch information
dhruvaray committed May 8, 2020
1 parent 8769235 commit 86619b9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 28 deletions.
30 changes: 3 additions & 27 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,57 +608,33 @@ def convert_tanh(self, op):
def convert_range(self, op):
"""Convert TFLite Range"""
try:
from tflite.Operator import Operator
from tflite.TensorType import TensorType
except ImportError:
raise ImportError("The tflite package must be installed")

if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized RANGE operator is not supported yet.')

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 3, "input tensors length should be 3"

start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2]
expressions = []

for t in [start, limit, delta]:
if self.has_expr(t.tensor_idx):
expressions.append(self.get_expr(t.tensor_idx))
else:
tensor_type = self.get_tensor_type_str(t.tensor.Type())
tensor_value = self.get_tensor_value(t)
expressions.append(self.exp_tab.new_const(tensor_value, dtype=tensor_type))
expressions = [self.get_tensor_expr(t) for t in [start, limit, delta]]

#out type inference
# out type inference
if delta.tensor.Type() == TensorType.FLOAT32:
out_type = self.get_tensor_type_str(delta.tensor.Type())
else:
out_type = self.get_tensor_type_str(start.tensor.Type())

#put type here form op
out = _op.arange(expressions[0], expressions[1], expressions[2], out_type)

return out

def convert_shape(self, op):
"""Convert TFLite Shape"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")

if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized SHAPE operator is not supported yet.')

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"

out = _op.shape_of(self.get_expr(input_tensors[0].tensor_idx))
out = _op.shape_of(self.get_tensor_expr(input_tensors[0]))

return out

Expand Down
3 changes: 2 additions & 1 deletion tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
continue

tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device,
num_output=len(out_names), out_names=out_names,mode=mode)
num_output=len(out_names), out_names=out_names, mode=mode)

# WARNING: the results could well be random values clipped to 0 or 255 because of badly tuned output
# range for the specific operator. While adding test ensure that we aren't getting only clipped values
Expand Down Expand Up @@ -910,6 +910,7 @@ def test_forward_shape():
mode="vm",
quantized=False
)

#######################################################################
# Concatenation
# -------------
Expand Down

0 comments on commit 86619b9

Please sign in to comment.