Skip to content

Commit

Permalink
[Torch] Add initial control flow support (apache#4964)
Browse files Browse the repository at this point in the history
* Add support for prim::If and prim::Loop with test cases

* rebase and fix tests

* add some comments

* simplifying, fix float cast

* parse -> convert

* recursively retrieve ops in get_all_op_names

* use multiple return values from block correctly, simplify loop convert

* choose dtype properly for zeros and ones

* simplifying, replace convert_inputs with _get_relay_input_vars

* fix for while loop with non input dependent init cond

* add assert on loop var update

* move the condition around

* better testing for seg models

* rebase fix, disable inception v3 in quant test as it is too slow to
load with torch-1.4 + torchvision 0.5

* simplify and add more comparison op converter
  • Loading branch information
masahi authored and Trevor Morris committed Apr 16, 2020
1 parent 105a366 commit df79344
Show file tree
Hide file tree
Showing 3 changed files with 385 additions and 38 deletions.
223 changes: 192 additions & 31 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""PT: PyTorch frontend."""
import itertools
import logging
import sys

import numpy as np

Expand All @@ -29,6 +30,7 @@
from .. import analysis as _analysis
from .. import expr as _expr
from .. import op as _op
from ..loops import while_loop
from .common import get_relay_op
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
Expand Down Expand Up @@ -107,9 +109,8 @@ def _select():
def _impl(inputs, input_types):
data = inputs[0]
dim = int(inputs[1])
index = int(inputs[2])

return _op.transform.take(data, _expr.const(index, dtype="int32"), axis=dim)
index = _wrap_const(inputs[2])
return _op.transform.take(data, index, axis=dim)
return _impl

def _ones():
Expand All @@ -126,7 +127,10 @@ def _impl(inputs, input_types):
else:
assert "data type {} could not be parsed in ones op" % (type(data))

return _op.full(_expr.const(1), shape, dtype=_convert_data_type(input_types[0]))
dtype_map = {6: "float32", 3: "int32"}
dtype_id = inputs[1]
assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id
return _op.full(_expr.const(1), shape, dtype=dtype_map[dtype_id])
return _impl

def _zeros():
Expand All @@ -143,7 +147,10 @@ def _impl(inputs, input_types):
else:
assert "data type {} could not be parsed in zeros op" % (type(data))

return _op.full(_expr.const(0), shape, dtype=_convert_data_type(input_types[0]))
dtype_map = {6: "float32", 3: "int32"}
dtype_id = inputs[1]
assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id
return _op.full(_expr.const(0), shape, dtype=dtype_map[dtype_id])
return _impl

def _relu():
Expand Down Expand Up @@ -222,12 +229,10 @@ def _impl(inputs, input_types):
else:
assert "data type {} could not be parsed in conv op" % (type(weight))

# TODO: Add reshape when channel multiplier > 1. Pending PR #4644
channels = weight_shape[0]
groups = int(inputs[8])

if groups > 1:
# in torch, groups == in_channels for depth wise conv
channel_multiplier = channels // groups
new_weight_shape = (groups, channel_multiplier, weight_shape[2], weight_shape[3])
weight = _op.transform.reshape(weight, new_weight_shape)
Expand Down Expand Up @@ -496,7 +501,7 @@ def _impl(inputs, input_types):
return _impl

def _reduce(name):
    """ Build a converter that applies the Relay op called ``name``.

    Parameters
    ----------
    name : str
        Name passed to ``get_relay_op`` to look up the Relay operator.

    Returns
    -------
    callable
        ``_impl(inputs, input_types)``, which applies the looked-up op to
        ``inputs[0]`` only.
    """
    # Converters in this file are invoked as relay_op(inputs, input_types),
    # so the implementation takes exactly two arguments; the stale
    # three-argument header (inputs, attrs, params) is dropped.
    def _impl(inputs, input_types):
        data = inputs[0]
        return get_relay_op(name)(data)
    return _impl
Expand Down Expand Up @@ -714,7 +719,6 @@ def func(x):

return _impl


def _expand_as():
def _impl(inputs, input_types):
# TODO: maybe fix this
Expand All @@ -724,6 +728,29 @@ def _impl(inputs, input_types):
return inputs[0]
return _impl

def _neg():
    """ Return a converter that negates its first input. """
    def _impl(inputs, input_types):
        # Only the first input is used; input_types is ignored.
        return _op.tensor.negative(inputs[0])
    return _impl

def _tanh():
    """ Return a converter that applies tanh to its first input. """
    def _impl(inputs, input_types):
        # Only the first input is used; input_types is ignored.
        return _op.tensor.tanh(inputs[0])
    return _impl

def _Bool():
def _impl(inputs, input_types):
assert len(inputs) == 1
return inputs[0]
return _impl

def _Float():
    """ Return a converter that casts its single input to float32. """
    def _impl(inputs, input_types):
        num_inputs = len(inputs)
        assert num_inputs == 1
        value = inputs[0]
        return _op.cast(value, "float32")
    return _impl

# Helper functions for operator implementation

Expand Down Expand Up @@ -780,6 +807,11 @@ def _convert_elemwise_input(data, input_type):
else:
return data

def _wrap_const(c):
    """ Wrap ``c`` in a Relay constant unless it is already usable as is.

    Relay expressions and Python lists are returned untouched; any other
    value is passed to ``_expr.const``.
    """
    if isinstance(c, _expr.Expr) or isinstance(c, list):
        return c
    return _expr.const(c)

# Operator mappings

_convert_map = {
Expand Down Expand Up @@ -845,7 +877,16 @@ def _convert_elemwise_input(data, input_type):
"aten::detach" : _identity(),
"aten::upsample_bilinear2d" : _upsample("bilinear"),
"aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
"aten::expand_as" : _expand_as()
"aten::expand_as" : _expand_as(),
"aten::lt" : _elemwise("less"),
"aten::gt" : _elemwise("greater"),
"aten::le" : _elemwise("less_equal"),
"aten::ge" : _elemwise("greater_equal"),
"aten::ne" : _elemwise("not_equal"),
"aten::Bool" : _Bool(),
"aten::Float" : _Float(),
"aten::neg" : _neg(),
"aten::tanh" : _tanh(),
}


Expand Down Expand Up @@ -894,7 +935,8 @@ def _report_missing_conversion(op_names):
""" Check if all ops in an input graph are supported by TVM """
known_ops = ["prim::Constant", "prim::GetAttr",
"prim::ListConstruct", "prim::ListUnpack",
"prim::TupleConstruct", "prim::TupleUnpack"]
"prim::TupleConstruct", "prim::TupleUnpack",
"prim::If", "prim::Loop"]
known_ops += list(_convert_map.keys())
known_ops += list(qnn_torch.convert_map.keys())

Expand Down Expand Up @@ -939,9 +981,13 @@ def _get_input_types(op_node):
input_node_kind = in_ty.kind()
if input_node_kind == 'TensorType':
if in_ty.scalarType() is None:
input_list_types.append(None)
# Tensor's type can be unknown if we use torch.jit.script(...)
# Defaults to float for now
logging.warning("Untyped Tensor found, assume it is float")
input_list_types.append("float")
else:
input_list_types.append(in_ty.scalarType().lower())

elif input_node_kind == 'ListType':
input_list_types.append(str(in_ty.getElementType()).lower())
elif input_node_kind in ['IntType', 'FloatType', 'BoolType',
Expand Down Expand Up @@ -1004,15 +1050,10 @@ def _get_operator_nodes(nodes):
return ops


def parse_inputs(graph_inputs, input_shapes):
    """ Return Relay vars from torch input vars

    One Relay var is created per name in ``input_shapes``, limited to as
    many names as there are graph inputs after the first one.
    """
    # The first graph input is skipped (presumably the module "self"
    # argument -- TODO confirm); the rest only bound how many names are used.
    trailing_inputs = list(graph_inputs)[1:]
    paired_names = [name for name, _ in zip(input_shapes, trailing_inputs)]
    return {name: _expr.var(name, shape=input_shapes[name])
            for name in paired_names}
def _get_relay_input_vars(input_shapes):
    """ Build a dict of Relay vars, one per (name, shape) entry of input_shapes """
    relay_vars = {}
    for var_name, var_shape in input_shapes.items():
        relay_vars[var_name] = _expr.var(var_name, shape=var_shape)
    return relay_vars


def get_use_chains(root_node, terminate=lambda _: False):
Expand Down Expand Up @@ -1055,7 +1096,7 @@ def terminate(users):
return get_use_chains(root_getattr_node, terminate)


def parse_params(graph, state_dict):
def convert_params(graph, state_dict):
"""
Return Relay vars and TVM NDArrays for input parameters
A chain of prim::GetAttr nodes is processed one at a time
Expand Down Expand Up @@ -1090,7 +1131,109 @@ def parse_params(graph, state_dict):
return params, param_tensors, packed_param_map


def parse_operators(operators, outputs, output_index_map, ret_name):
def convert_block(block, outputs, output_index_map):
    """ Translate Torch "Block", used for prim::If and prim::Loop

    Converts the operator nodes of ``block`` and returns the list produced
    by ``convert_operators`` for the names feeding the block's return node.
    """
    return convert_operators(_get_operator_nodes(block.nodes()),
                             outputs,
                             output_index_map,
                             _get_input_names(block.returnNode()))


def convert_if(if_node, outputs, output_index_map):
    """ Translate Torch prim::If to Relay If

    The condition is the already-converted value named by the node's first
    input. The node's first block becomes the true branch and its second
    block the false branch; each must produce exactly one value.
    """
    cond_name = if_node.inputsAt(0).debugName()
    cond = outputs[output_index_map[cond_name]]
    branch_blocks = list(if_node.blocks())
    # Branches are converted in order (true, then false) since conversion
    # appends to the shared outputs list.
    then_vals = convert_block(branch_blocks[0], outputs, output_index_map)
    else_vals = convert_block(branch_blocks[1], outputs, output_index_map)
    assert len(then_vals) == 1
    assert len(else_vals) == 1
    return _expr.If(cond, then_vals[0], else_vals[0])


def convert_loop(loop_node, outputs, output_index_map):
    """ Translate Torch prim::Loop to Relay while_loop

    Parameters
    ----------
    loop_node
        The prim::Loop node. Input 0 is the max trip count, input 1 is the
        initial condition and the remaining inputs are the loop variables.
    outputs : list
        Values converted so far. Entries belonging to the body block inputs
        are overwritten in place while the loop body is converted.
    output_index_map : dict
        Maps a Torch value name to its index in ``outputs``.

    Returns
    -------
    list
        One ``TupleGetItem`` per loop variable, taken from positions
        1..num_loop_var of the loop result (position 0 holds the loop
        counter or the boolean condition and is dropped).
    """
    def get_input(index):
        # Fetch the index-th loop input as a Relay expression: constants are
        # materialized directly, anything else must already be converted.
        ivalue = loop_node.inputsAt(index)
        inode = ivalue.node()
        if inode.kind() == "prim::Constant":
            return _expr.const(_get_constant(inode))
        var_name = ivalue.debugName()
        assert var_name in output_index_map
        return _wrap_const(outputs[output_index_map[var_name]])

    # Refer to the spec for prim::Loop below
    # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
    # The first input: %max_trip_count
    # The second input: %initial_condition
    # The rest of input: loop variables
    max_loop_count = get_input(0)
    init_cond = get_input(1)
    num_loop_var = len(list(loop_node.inputs())) - 2
    init_vals = [get_input(i + 2) for i in range(num_loop_var)]

    # while loop has always max_loop_count being int64 max
    # max_loop_count.data (tvm.runtime.NDArray) is -1, so _get_constant again
    # (i.e. the Python-side constant is re-read and compared to sys.maxsize
    # instead of inspecting the Relay constant's data)
    is_while_loop = (isinstance(max_loop_count, _expr.Constant) and
                     _get_constant(loop_node.inputsAt(0).node()) == sys.maxsize)

    body_block = list(loop_node.blocks())[0]
    block_input_names = _get_input_names(body_block)

    def cond(*current_vals):
        # Loop condition: the first loop value is either the boolean
        # condition (while loop) or the iteration counter (for loop).
        i = current_vals[0]

        if is_while_loop:
            return _op.equal(i, _expr.const(True, 'bool'))

        return _op.less(i, max_loop_count)

    def body(*current_vals):
        # Update loop variables using the prev iteration outputs
        assert len(current_vals) == len(block_input_names)
        for (i, iname) in enumerate(block_input_names):
            outputs[output_index_map[iname]] = current_vals[i]

        block_outputs = convert_block(body_block, outputs, output_index_map)

        if not is_while_loop:
            # iter var increment implicit in torch, so do it manually
            # for while loop, block_outputs[0] is already a boolean,
            # the result of termination check
            incr = _expr.const(1, dtype="int32")
            block_outputs[0] = current_vals[0] + incr

        return block_outputs

    def get_var(name, val):
        # Create the Relay var for a loop variable; shape and dtype are only
        # attached when the initial value is a constant.
        if isinstance(val, _expr.Constant):
            return _expr.var(name, shape=val.data.shape, dtype=val.data.dtype)
        return _expr.var(name)

    if is_while_loop:
        loop_iter_dtype = "bool"
        # while loop with non input dependent condition such as while i < 10:
        # init_cond is int, need to cast to bool to type check
        if isinstance(init_cond, _expr.Constant):
            init_cond = _op.cast(init_cond, "bool")
        init_loop_iter_val = init_cond
    else:
        loop_iter_dtype = "int32"
        # always count from 0
        init_loop_iter_val = _expr.const(0, dtype="int32")

    # Pair each body block input name with its initial value; the first
    # pair is the loop counter / condition.
    name_val_pairs = list(zip(block_input_names,
                              [init_loop_iter_val] + init_vals))
    # Presumably registers the initial values under the block input names so
    # that body() can overwrite them by index -- TODO confirm against
    # _update_outputs_from_pairs.
    _update_outputs_from_pairs(name_val_pairs, outputs, output_index_map)

    loop_iter_var = _expr.var(block_input_names[0], shape=(),
                              dtype=loop_iter_dtype)
    loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]
    loop = while_loop(cond, [loop_iter_var] + loop_vars, body)
    loop_val = loop(init_loop_iter_val, *init_vals)

    # The first element is a loop counter or boolean condition, ignore it
    return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]


def convert_operators(operators, outputs, output_index_map, ret_names):
""" Convert each Torch IR operators to Relay equivalent """
for node_name, op_node in operators:
operator = op_node.kind()
Expand All @@ -1110,17 +1253,35 @@ def parse_operators(operators, outputs, output_index_map, ret_name):
unpacked_names = _get_output_names(op_node)
_update_outputs_from_pairs(zip(unpacked_names, inputs[0]),
outputs, output_index_map)
elif operator == "prim::If":
if_out = convert_if(op_node, outputs, output_index_map)
output_index_map[node_name] = len(outputs)
outputs.append(if_out)
elif operator == "prim::Loop":
loop_out = convert_loop(op_node, outputs, output_index_map)
unpacked_names = _get_output_names(op_node)
assert len(loop_out) == len(unpacked_names)
_update_outputs_from_pairs(zip(unpacked_names, loop_out),
outputs, output_index_map)
else:
output_index_map[node_name] = len(outputs)
relay_op = _convert_map[operator]
outputs.append(relay_op(inputs, _get_input_types(op_node)))

return outputs[output_index_map[ret_name]]
return [_wrap_const(outputs[output_index_map[ret_name]])
for ret_name in ret_names]


def get_all_op_names(graph):
    """ Return all operator names in the input graph

    Besides the top-level nodes, the nodes inside the blocks of every
    prim::If and prim::Loop node (found recursively) are included.

    Parameters
    ----------
    graph
        Object exposing ``nodes()`` and ``findAllNodes(kind, recurse=True)``.

    Returns
    -------
    set
        The ``kind()`` of every collected node.
    """
    # No early return on the top-level nodes: doing so would skip the ops
    # nested inside control flow blocks.
    nodes = list(graph.nodes())
    prim_with_blocks = ["prim::If", "prim::Loop"]
    for prim in prim_with_blocks:
        for prim_node in graph.findAllNodes(prim, recurse=True):
            for block in prim_node.blocks():
                nodes.extend(block.nodes())
    return {node.kind() for node in nodes}


def get_graph_input_names(script_module):
Expand Down Expand Up @@ -1167,14 +1328,14 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
_check_input_names(script_module, input_shapes)

params = script_module.state_dict()
input_vars = parse_inputs(graph.inputs(), input_shapes)
param_vars, tensors, packed_param_map = parse_params(graph, params)
input_vars = _get_relay_input_vars(input_shapes)
param_vars, tensors, packed_param_map = convert_params(graph, params)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}

input_vars.update(param_vars)
outputs = list(input_vars.values())
output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
ret_name = _get_input_names(graph.return_node())[0]
ret_name = _get_input_names(graph.return_node())

# For quantized models
if "aten::quantize_per_tensor" in op_names:
Expand All @@ -1186,8 +1347,8 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
qnn_torch.add_quant_params(tvm_params, weight_quant_params)
_convert_map.update(qnn_torch.convert_map)

body = parse_operators(_get_operator_nodes(graph.nodes()), outputs,
output_index_map, ret_name)
func = tvm.relay.Function(_analysis.free_vars(body), body)
ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs,
output_index_map, ret_name)
func = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])

return _module.IRModule.from_expr(func), tvm_params
3 changes: 2 additions & 1 deletion tests/python/frontend/pytorch/qnn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,8 @@ def get_imagenet_input():
qmodels += [
("resnet18", qresnet.resnet18(pretrained=True), per_channel),
("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel),
("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
# disable inception test for now, since loading it takes ~5min on torchvision-0.5
#("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
("googlenet", qgooglenet(pretrained=True), per_channel),
]

Expand Down
Loading

0 comments on commit df79344

Please sign in to comment.