Skip to content

Commit

Permalink
Only restrict to tflite.Operator
Browse files Browse the repository at this point in the history
  • Loading branch information
Ramana Radhakrishnan committed Apr 17, 2020
1 parent a754f01 commit 093c314
Showing 1 changed file with 57 additions and 15 deletions.
72 changes: 57 additions & 15 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,6 @@
from .common import ExprTable
from .common import infer_shape as _infer_shape

# A note on tflite imports. Operator specific imports of modules
# need to be with the operator. General imports that are common across
# multiple operators and which are very common should be in the
# block below.
try:
from tflite.Operator import Operator as Operator
from tflite.TensorType import TensorType as TensorType
from tflite.BuiltinOperator import BuiltinOperator as BuiltinOperator
from tflite.BuiltinOptions import BuiltinOptions as BuiltinOptions
from tflite.ActivationFunctionType import ActivationFunctionType as ActivationFunctionType
except ImportError:
raise ImportError("The tflite package must be installed")


__all__ = ['from_tflite']

class TensorWrapper(object):
Expand All @@ -60,6 +46,12 @@ class OperatorConverter(object):
"""Operator Converted for converting TFLite ops to Relay ops"""
def __init__(self, model, subgraph, exp_tab):

try:
from tflite.BuiltinOperator import BuiltinOperator
from tflite.BuiltinOptions import BuiltinOptions
from tflite.ActivationFunctionType import ActivationFunctionType
except ImportError:
raise ImportError("The tflite package must be installed")

self.model = model
self.subgraph = subgraph
Expand Down Expand Up @@ -167,6 +159,10 @@ def convert_op_to_relay(self):
op = self.subgraph.Operators(op_idx)
op_code_str = self.get_op_code_str(op)
output_tensors = self.get_output_tensors(op)
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
ret = self.convert_map[op_code_str](op)
Expand All @@ -181,6 +177,10 @@ def convert_op_to_relay(self):

def get_op_code_str(self, op):
"""Get TFLite ops string representation"""
try:
from tflite.BuiltinOperator import BuiltinOperator
except ImportError:
raise ImportError("The tflite package must be installed")

op_code_list_idx = op.OpcodeIndex()
op_code_id = self.model.OperatorCodes(op_code_list_idx).BuiltinCode()
Expand Down Expand Up @@ -236,6 +236,11 @@ def get_tensor_value(self, tensor_wrapper):
"""Get tensor buffer value from given tensor wrapper"""
assert isinstance(tensor_wrapper, TensorWrapper)

try:
from tflite.TensorType import TensorType
except ImportError:
raise ImportError("The tflite package must be installed")

if tensor_wrapper.tensor.Type() == TensorType.UINT8:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.uint8).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
Expand All @@ -256,6 +261,11 @@ def get_tensor_value(self, tensor_wrapper):

def get_tensor_type_str(self, tensor_type):
"""Get tensor type string representation when given TFLite tensor type"""
try:
from tflite.TensorType import TensorType
except ImportError:
raise ImportError("The tflite package must be installed")

if tensor_type == TensorType.UINT8:
return "uint8"
if tensor_type == TensorType.FLOAT32:
Expand Down Expand Up @@ -283,7 +293,6 @@ def has_same_qnn_params(self, lhs_tensor, rhs_tensor):

def is_quantized(self, op):
"""Check if an input tensor is quantized."""

input_tensors = self.get_input_tensors(op)
first_tensor = input_tensors[0]
return first_tensor.qnn_params is not None
Expand Down Expand Up @@ -324,6 +333,7 @@ def convert_max_pool2d(self, op):
def convert_reshape(self, op):
"""Convert TFLite reshape"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.ReshapeOptions import ReshapeOptions
except ImportError:
raise ImportError("The tflite package must be installed")
Expand Down Expand Up @@ -354,6 +364,7 @@ def convert_reshape(self, op):
def _convert_resize(self, method, op):
"""Generic method to Convert TFLite RESIZE operators"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.ResizeBilinearOptions import ResizeBilinearOptions
# ResizeNearestNeighborOptions was added in tflite v1.13
tflite_ver = 1120
Expand Down Expand Up @@ -405,7 +416,9 @@ def convert_resize_nearest_neighbor(self, op):
def convert_l2_normalization(self, op):
"""Convert TFLite L2_NORMALIZATION """
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.L2NormOptions import L2NormOptions
from tflite.ActivationFunctionType import ActivationFunctionType
except ImportError:
raise ImportError("The tflite package must be installed")

Expand Down Expand Up @@ -447,6 +460,7 @@ def convert_l2_normalization(self, op):
def convert_lrn(self, op):
"""Convert TFLite LOCAL_RESPONSE_NORMALIZATION """
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.LocalResponseNormalizationOptions import LocalResponseNormalizationOptions
except ImportError:
raise ImportError("The tflite package must be installed")
Expand Down Expand Up @@ -583,6 +597,8 @@ def convert_concatenation(self, op):
"""Convert TFLite concatenation"""
try:
from tflite.ConcatenationOptions import ConcatenationOptions
from tflite.BuiltinOptions import BuiltinOptions
from tflite.ActivationFunctionType import ActivationFunctionType
except ImportError:
raise ImportError("The tflite package must be installed")

Expand Down Expand Up @@ -763,6 +779,8 @@ def _convert_elemwise(self, relay_op, op):
from tflite.SubOptions import SubOptions
from tflite.MulOptions import MulOptions
from tflite.DivOptions import DivOptions
from tflite.BuiltinOptions import BuiltinOptions
from tflite.ActivationFunctionType import ActivationFunctionType
except ImportError:
raise ImportError("The tflite package must be installed")

Expand Down Expand Up @@ -979,6 +997,7 @@ def convert_zeros_like(self, op):
def _convert_reduce(self, relay_op, op):
"""Generic method to Convert TFLite MEAN operators"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.ReducerOptions import ReducerOptions
except ImportError:
raise ImportError("The tflite package must be installed")
Expand Down Expand Up @@ -1042,6 +1061,9 @@ def convert_fully_connected(self, op):
"""Convert TFLite fully connected"""
try:
from tflite.FullyConnectedOptions import FullyConnectedOptions
from tflite.BuiltinOptions import BuiltinOptions
from tflite.TensorType import TensorType
from tflite.ActivationFunctionType import ActivationFunctionType
except ImportError:
raise ImportError("The tflite package must be installed")

Expand Down Expand Up @@ -1138,6 +1160,7 @@ def convert_fully_connected(self, op):
def convert_squeeze(self, op):
"""Convert TFLite squeeze"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.SqueezeOptions import SqueezeOptions
except ImportError:
raise ImportError("The tflite package must be installed")
Expand All @@ -1162,6 +1185,10 @@ def convert_squeeze(self, op):

def convert_fused_activation_function(self, in_expr, fused_activation_fn):
"""Convert TFLite fused activation function"""
try:
from tflite.ActivationFunctionType import ActivationFunctionType
except ImportError:
raise ImportError("The tflite package must be installed")
assert fused_activation_fn != ActivationFunctionType.NONE
if fused_activation_fn == ActivationFunctionType.RELU6:
return _op.clip(in_expr, a_min=0, a_max=6)
Expand All @@ -1178,6 +1205,9 @@ def convert_fused_activation_function(self, in_expr, fused_activation_fn):
def convert_conv(self, op, conv_type):
"""convolution implementation."""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.ActivationFunctionType import ActivationFunctionType
from tflite.TensorType import TensorType
from tflite.Conv2DOptions import Conv2DOptions
from tflite.DepthwiseConv2DOptions import DepthwiseConv2DOptions
from tflite.Padding import Padding
Expand Down Expand Up @@ -1343,6 +1373,7 @@ def convert_conv(self, op, conv_type):
def convert_split(self, op):
"""split implementation."""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.SplitOptions import SplitOptions
except ImportError:
raise ImportError("The tflite package must be installed")
Expand Down Expand Up @@ -1419,6 +1450,7 @@ def convert_transpose(self, op):
def convert_cast(self, op):
"""Convert TFLite CAST"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.CastOptions import CastOptions
except ImportError:
raise ImportError("The tflite package must be installed")
Expand Down Expand Up @@ -1469,6 +1501,8 @@ def convert_topk_v2(self, op):
def convert_pool2d(self, op, pool_type):
"""pool2d implementation."""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.ActivationFunctionType import ActivationFunctionType
from tflite.Pool2DOptions import Pool2DOptions
from tflite.Padding import Padding
except ImportError:
Expand Down Expand Up @@ -1589,6 +1623,7 @@ def convert_floor_mod(self, op):
def convert_mirror_pad(self, op):
"""Convert TFLite MIRROR_PAD"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.MirrorPadOptions import MirrorPadOptions
except ImportError:
raise ImportError("The tflite package must be installed")
Expand Down Expand Up @@ -1624,6 +1659,7 @@ def convert_mirror_pad(self, op):
def convert_pack(self, op):
"""Convert TFLite pack"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.PackOptions import PackOptions
except ImportError:
raise ImportError("The tflite package must be installed")
Expand All @@ -1648,6 +1684,7 @@ def convert_pack(self, op):
def convert_unpack(self, op):
"""Convert TFLite unpack"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.UnpackOptions import UnpackOptions
except ImportError:
raise ImportError("The tflite package must be installed")
Expand Down Expand Up @@ -1788,6 +1825,7 @@ def convert_space_to_batch_nd(self, op):
def convert_depth_to_space(self, op):
"""Convert TFLite DEPTH_TO_SPACE"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.DepthToSpaceOptions import DepthToSpaceOptions
except ImportError:
raise ImportError("The tflite package must be installed")
Expand All @@ -1810,6 +1848,7 @@ def convert_depth_to_space(self, op):
def convert_space_to_depth(self, op):
"""Convert TFLite SPACE_TO_DEPTH"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.SpaceToDepthOptions import SpaceToDepthOptions
except ImportError:
raise ImportError("The tflite package must be installed")
Expand Down Expand Up @@ -1848,6 +1887,8 @@ def convert_prelu(self, op):
def convert_transpose_conv(self, op):
"""Convert TFLite TRANSPOSE_CONV"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.TensorType import TensorType
from tflite.TransposeConvOptions import TransposeConvOptions
from tflite.Padding import Padding
except ImportError:
Expand Down Expand Up @@ -2207,6 +2248,7 @@ def from_tflite(model, shape_dict, dtype_dict):
try:
import tflite.Model
import tflite.SubGraph
import tflite.BuiltinOperator
except ImportError:
raise ImportError("The tflite package must be installed")
assert isinstance(model, tflite.Model.Model)
Expand Down

0 comments on commit 093c314

Please sign in to comment.