Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend][Torch] Fix up graph input handling #5204

Merged
merged 6 commits into from
Apr 2, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 84 additions & 71 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,16 +1007,13 @@ def _get_input_names(node_or_graph):
return [inp.debugName() for inp in node_or_graph.inputs()]


def _get_op_inputs(op_node, outputs, output_index_map):
input_names = [output_index_map[name]
for name in _get_input_names(op_node)]
return [outputs[name] for name in input_names]
def _get_op_inputs(op_node, input_vars):
Copy link
Member

@masahi masahi Apr 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename `input_vars` to `outputs`: the inputs of an op are the outputs of earlier nodes, and those are not necessarily `relay.Var`s.

return [input_vars[name] for name in _get_input_names(op_node)]


def _update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):
for output_name, output in name_output_pairs:
output_index_map[output_name] = len(outputs)
outputs.append(output)
def _update_inputs_from_pairs(name_input_pairs, input_vars):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't need this function anymore. The dict's own `update` method can be used instead, since it accepts an iterable of (name, value) pairs.

for input_name, inp in name_input_pairs:
input_vars[input_name] = inp


def _report_missing_conversion(op_names):
Expand All @@ -1036,18 +1033,31 @@ def _report_missing_conversion(op_names):
raise NotImplementedError(msg)


def _check_input_names(script_module, input_shapes):
""" Check the graph inputs match the inputs """
ir_inputs = get_graph_input_names(script_module)

for ir_input in ir_inputs:
if ir_input not in input_shapes:
msg = "Missing graph input {} in input_shapes".format(ir_input)
raise RuntimeError(msg)

for input_name in input_shapes:
if input_name not in ir_inputs:
msg = "Unused graph input {} in input_shapes".format(input_name)
def _check_inputs(graph, input_shapes):
    """
    Validate the user supplied input_shapes against the graph inputs

    input_shapes has to be a list holding one ('name', shape) tuple per
    graph input. Too few entries or a malformed entry raise RuntimeError,
    surplus entries are only reported with a warning.
    """
    ir_inputs = _get_graph_input_names(graph)

    if not isinstance(input_shapes, list):
        raise RuntimeError("Graph inputs input_shapes should be list")

    num_missing = len(ir_inputs) - len(input_shapes)
    if num_missing > 0:
        raise RuntimeError(
            "Missing {} graph input(s) in input_shapes".format(num_missing))

    num_expected = len(ir_inputs)
    for idx, entry in enumerate(input_shapes):
        if idx >= num_expected:
            # Entries beyond the graph inputs are tolerated, just reported
            logging.warning(
                "Unused graph input {} in input_shapes".format(entry))
            continue
        if not isinstance(entry, tuple):
            raise RuntimeError("Graph input {} is not a tuple".format(idx))
        if len(entry) != 2 or not isinstance(entry[0], str):
            raise RuntimeError(
                "Graph input {} is not valid, expected ('name', shape)".format(entry))


Expand Down Expand Up @@ -1139,10 +1149,20 @@ def _get_operator_nodes(nodes):
return ops


def _get_relay_input_vars(input_shapes):
""" Return Relay vars from input shapes """
return {iname: _expr.var(iname, shape=ishape)
for iname, ishape in input_shapes.items()}
def _get_relay_input_vars(graph, input_shapes):
"""
Return Relay vars from input shapes and create entries based on
expected graph inputs - to allow translation
"""
input_vars = {}
ir_inputs = _get_graph_input_names(graph)
for idx, ir_input in enumerate(ir_inputs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about

for ir_input, (name, shape) in zip(ir_inputs, input_shapes):
    ...

name, shape = input_shapes[idx]
inp = _expr.var(name, shape=shape)
# Translate from graph input to user input name
input_vars[ir_input] = inp

return input_vars


def get_use_chains(root_node, terminate=lambda _: False):
Expand Down Expand Up @@ -1220,33 +1240,33 @@ def convert_params(graph, state_dict):
return params, param_tensors, packed_param_map


def convert_block(block, outputs, output_index_map):
def convert_block(block, input_vars):
""" Translate Torch "Block", used for prim::If and prim::Loop """
ops = _get_operator_nodes(block.nodes())
ret_names = _get_input_names(block.returnNode())
return convert_operators(ops, outputs, output_index_map, ret_names)
return convert_operators(ops, input_vars, ret_names)


def convert_if(if_node, outputs, output_index_map):
def convert_if(if_node, input_vars):
""" Translate Torch prim::If to Relay If """
cond = outputs[output_index_map[if_node.inputsAt(0).debugName()]]
cond = input_vars[if_node.inputsAt(0).debugName()]
blocks = list(if_node.blocks())
true_branch = convert_block(blocks[0], outputs, output_index_map)
false_branch = convert_block(blocks[1], outputs, output_index_map)
true_branch = convert_block(blocks[0], input_vars)
false_branch = convert_block(blocks[1], input_vars)
assert len(true_branch) == 1 and len(false_branch) == 1
return _expr.If(cond, true_branch[0], false_branch[0])


def convert_loop(loop_node, outputs, output_index_map):
def convert_loop(loop_node, input_vars):
""" Translate Torch prim::Loop to Relay while_loop """
def get_input(index):
ivalue = loop_node.inputsAt(index)
inode = ivalue.node()
if inode.kind() == "prim::Constant":
return _expr.const(_get_constant(inode))
var_name = ivalue.debugName()
assert var_name in output_index_map
return _wrap_const(outputs[output_index_map[var_name]])
assert var_name in input_vars
return _wrap_const(input_vars[var_name])

# Refer to the spec for prim::Loop below
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
Expand Down Expand Up @@ -1278,9 +1298,9 @@ def body(*current_vals):
# Update loop variables using the prev iteration outputs
assert len(current_vals) == len(block_input_names)
for (i, iname) in enumerate(block_input_names):
outputs[output_index_map[iname]] = current_vals[i]
input_vars[iname] = current_vals[i]

block_outputs = convert_block(body_block, outputs, output_index_map)
block_outputs = convert_block(body_block, input_vars)

if not is_while_loop:
# iter var increment implicit in torch, so do it manually
Expand Down Expand Up @@ -1310,7 +1330,7 @@ def get_var(name, val):

name_val_pairs = list(zip(block_input_names,
[init_loop_iter_val] + init_vals))
_update_outputs_from_pairs(name_val_pairs, outputs, output_index_map)
_update_inputs_from_pairs(name_val_pairs, input_vars)

loop_iter_var = _expr.var(block_input_names[0], shape=(),
dtype=loop_iter_dtype)
Expand All @@ -1322,36 +1342,32 @@ def get_var(name, val):
return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]


def convert_operators(operators, outputs, output_index_map, ret_names):
def convert_operators(operators, input_vars, ret_names):
""" Convert each Torch IR operators to Relay equivalent """
for node_name, op_node in operators:
operator = op_node.kind()
inputs = _get_op_inputs(op_node, outputs, output_index_map)
inputs = _get_op_inputs(op_node, input_vars)

if operator == "prim::Constant":
output_index_map[node_name] = len(outputs)
outputs.append(_get_constant(op_node))
input_vars[node_name] = _get_constant(op_node)
elif operator == 'prim::ListConstruct' and _is_int_seq(inputs):
output_index_map[node_name] = len(outputs)
outputs.append(_expr.var(node_name, shape=inputs))
input_vars[node_name] = _expr.var(node_name, shape=inputs)
elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']:
output_index_map[node_name] = len(outputs)
outputs.append(inputs)
input_vars[node_name] = inputs
elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']:
assert len(inputs) == 1
unpacked_names = _get_output_names(op_node)
_update_outputs_from_pairs(zip(unpacked_names, inputs[0]),
outputs, output_index_map)
_update_inputs_from_pairs(zip(unpacked_names, inputs[0]),
input_vars)
elif operator == "prim::If":
if_out = convert_if(op_node, outputs, output_index_map)
output_index_map[node_name] = len(outputs)
outputs.append(if_out)
if_out = convert_if(op_node, input_vars)
input_vars[node_name] = if_out
elif operator == "prim::Loop":
loop_out = convert_loop(op_node, outputs, output_index_map)
loop_out = convert_loop(op_node, input_vars)
unpacked_names = _get_output_names(op_node)
assert len(loop_out) == len(unpacked_names)
_update_outputs_from_pairs(zip(unpacked_names, loop_out),
outputs, output_index_map)
_update_inputs_from_pairs(zip(unpacked_names, loop_out),
input_vars)
else:
relay_op = _convert_map[operator]
relay_out = relay_op(inputs, _get_input_types(op_node))
Expand All @@ -1360,13 +1376,12 @@ def convert_operators(operators, outputs, output_index_map, ret_names):
# This is for torch operators that return multiple outputs
# See _adaptive_max_2d above for example
out_names = _get_output_names(op_node)
_update_outputs_from_pairs(zip(out_names, relay_out),
outputs, output_index_map)
_update_inputs_from_pairs(zip(out_names, relay_out),
input_vars)
else:
output_index_map[node_name] = len(outputs)
outputs.append(relay_out)
input_vars[node_name] = relay_out

return [_wrap_const(outputs[output_index_map[ret_name]])
return [_wrap_const(input_vars[ret_name])
for ret_name in ret_names]


Expand All @@ -1382,11 +1397,11 @@ def get_all_op_names(graph):
return set(node.kind() for node in nodes)


def get_graph_input_names(script_module):
""" Use this function to set the keys for input_shapes"""
# It seems variable names could change the first time a copy is made
# Use the copy of the graph here to prevent troubles later
ir_inputs = _get_input_names(script_module.graph.copy())
def _get_graph_input_names(graph):
    """ Return the graph input names without the leading self argument

    Meant to be used on a graph that was already copied and run through
    the jit passes.
    """
    # Names can change on the first graph copy and again once
    # _run_jit_passes was called, so both are expected to have happened
    all_names = _get_input_names(graph)
    user_names = all_names[1:]  # the 0th arg is self
    return user_names


Expand Down Expand Up @@ -1423,30 +1438,28 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):

op_names = get_all_op_names(graph)
_report_missing_conversion(op_names)
_check_input_names(script_module, input_shapes)
_check_inputs(graph, input_shapes)

params = script_module.state_dict()
input_vars = _get_relay_input_vars(input_shapes)
input_vars = _get_relay_input_vars(graph, input_shapes)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename `input_vars` to `outputs` here as well.

param_vars, tensors, packed_param_map = convert_params(graph, params)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}

input_vars.update(param_vars)
outputs = list(input_vars.values())
output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
ret_name = _get_input_names(graph.return_node())

# For quantized models
if "aten::quantize_per_tensor" in op_names:
weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
qnn_torch.add_input_quant_params_to_op_inputs(graph)
qnn_torch.add_quant_params_to_outputs(outputs, output_index_map,
packed_param_map,
weight_quant_params)
qnn_torch.add_quant_params_to_inputs(input_vars,
packed_param_map,
weight_quant_params)
qnn_torch.add_quant_params(tvm_params, weight_quant_params)
_convert_map.update(qnn_torch.convert_map)

ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs,
output_index_map, ret_name)
ret = convert_operators(_get_operator_nodes(graph.nodes()),
input_vars, ret_name)

if isinstance(ret[0], list):
ret[0] = _expr.Tuple(ret[0])
Expand Down
9 changes: 4 additions & 5 deletions python/tvm/relay/frontend/qnn_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,19 @@ def get_weight_quant_params(script_module):
return quant_params


def add_quant_params_to_outputs(outputs, output_index_map,
packed_param_map, quant_params):
def add_quant_params_to_inputs(input_vars, packed_param_map,
quant_params):
"""
Add quant params to outputs so that they can be referenced by other
Add quant params to inputs so that they can be referenced by other
ops later. Weights are quantized here.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For L104 and L107 of qnn_torch.py (the function name and the first docstring line), please keep "outputs" instead of "inputs".

"""
for node_name, packed_param_name in packed_param_map.items():
qparam = quant_params[packed_param_name]
output_index_map[node_name] = len(outputs)
qweight = relay.qnn.op.quantize(qparam.weight_var, qparam.scale,
qparam.zero_point, out_dtype="int8",
axis=0)
param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var)
outputs.append(param_tup)
input_vars[node_name] = param_tup


def _get_quant_param_for_input(input_value):
Expand Down
7 changes: 3 additions & 4 deletions tests/python/frontend/pytorch/qnn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

import tvm
from tvm import relay
from tvm.relay.frontend.pytorch import get_graph_input_names
from tvm.contrib.download import download_testdata


Expand All @@ -39,7 +38,7 @@ def torch_version_check():

def get_tvm_runtime(script_module, input_name, ishape):

input_shapes = {input_name: ishape}
input_shapes = [(input_name, ishape)]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)

with relay.build_config(opt_level=3):
Expand Down Expand Up @@ -287,7 +286,7 @@ def test_quantized_modules():
with torch.no_grad():
pt_result = script_module(inp.clone()).numpy()

input_name = get_graph_input_names(script_module)[0]
input_name = "input"
runtime = get_tvm_runtime(script_module, input_name, ishape)
runtime.set_input(input_name, inp.numpy().copy())
runtime.run()
Expand Down Expand Up @@ -383,7 +382,7 @@ def get_imagenet_input():
with torch.no_grad():
pt_result = script_module(pt_inp).numpy()

input_name = get_graph_input_names(script_module)[0]
input_name = "image"
runtime = get_tvm_runtime(script_module, input_name, (1, 3, 224, 224))
runtime.set_input(input_name, inp)
runtime.run()
Expand Down
14 changes: 7 additions & 7 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list
from tvm.relay.frontend.pytorch import get_graph_input_names


sys.setrecursionlimit(10000)
Expand Down Expand Up @@ -169,8 +168,8 @@ def verify_model(model_name, input_data=[],
else:
trace = trace.cpu()

input_names = get_graph_input_names(trace)
input_shapes = dict(zip(input_names,
input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
input_shapes = list(zip(input_names,
[inp.shape for inp in baseline_input]))
mod, params = relay.frontend.from_pytorch(trace, input_shapes,
custom_convert_map)
Expand Down Expand Up @@ -890,11 +889,12 @@ def test_3d_models():

def verify_script_model(pt_model, ishapes):
script_module = torch.jit.script(pt_model)
input_names = get_graph_input_names(script_module)
input_shapes = dict(zip(input_names, ishapes))

inputs = [torch.randn(input_shapes[input_name], dtype=torch.float)
for input_name in input_names]
input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
input_shapes = list(zip(input_names, ishapes))

inputs = [torch.randn(shape, dtype=torch.float)
for name, shape in input_shapes]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since `name` is unused here, this can simply be `for shape in ishapes`.


mod, params = relay.frontend.from_pytorch(script_module, input_shapes)

Expand Down