[Frontend][Torch] Fix up graph input handling #5204
File: python/tvm/relay/frontend/pytorch.py
@@ -1007,16 +1007,13 @@ def _get_input_names(node_or_graph):
     return [inp.debugName() for inp in node_or_graph.inputs()]


-def _get_op_inputs(op_node, outputs, output_index_map):
-    input_names = [output_index_map[name]
-                   for name in _get_input_names(op_node)]
-    return [outputs[name] for name in input_names]
+def _get_op_inputs(op_node, input_vars):
+    return [input_vars[name] for name in _get_input_names(op_node)]


-def _update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):
-    for output_name, output in name_output_pairs:
-        output_index_map[output_name] = len(outputs)
-        outputs.append(output)
+def _update_inputs_from_pairs(name_input_pairs, input_vars):

[inline review comment] I think we don't need this function anymore. Dict's …

+    for input_name, inp in name_input_pairs:
+        input_vars[input_name] = inp


 def _report_missing_conversion(op_names):

@@ -1036,18 +1033,31 @@ def _report_missing_conversion(op_names):
     raise NotImplementedError(msg)


-def _check_input_names(script_module, input_shapes):
-    """ Check the graph inputs match the inputs """
-    ir_inputs = get_graph_input_names(script_module)
-
-    for ir_input in ir_inputs:
-        if ir_input not in input_shapes:
-            msg = "Missing graph input {} in input_shapes".format(ir_input)
-            raise RuntimeError(msg)
-
-    for input_name in input_shapes:
-        if input_name not in ir_inputs:
-            msg = "Unused graph input {} in input_shapes".format(input_name)
+def _check_inputs(graph, input_shapes):
+    """
+    Check the graph inputs match the expected number of inputs
+    and are in the correct format
+    """
+    ir_inputs = _get_graph_input_names(graph)
+
+    if not isinstance(input_shapes, list):
+        msg = "Graph inputs input_shapes should be list"
+        raise RuntimeError(msg)
+    missing_inputs = len(ir_inputs) - len(input_shapes)
+    if missing_inputs > 0:
+        msg = "Missing {} graph input(s) in input_shapes".format(missing_inputs)
+        raise RuntimeError(msg)
+
+    for num, inp in enumerate(input_shapes):
+        if num < len(ir_inputs):
+            if not isinstance(inp, tuple):
+                msg = "Graph input {} is not a tuple".format(num)
+                raise RuntimeError(msg)
+            if (len(inp) != 2 or not isinstance(inp[0], str)):
+                msg = "Graph input {} is not valid, expected ('name', shape)".format(inp)
+                raise RuntimeError(msg)
+        else:
+            msg = "Unused graph input {} in input_shapes".format(inp)
             logging.warning(msg)

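To make the new contract concrete, here is a sketch of the format _check_inputs now accepts; the names and shapes are illustrative only:

# An ordered list of (name, shape) tuples, one per actual graph input.
input_shapes = [("input0", (1, 3, 224, 224)),
                ("input1", (1, 10))]

# The old dict form now fails the isinstance(input_shapes, list) check,
# and entries like "input0" or ("input0",) fail the tuple-format checks.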
@@ -1139,10 +1149,20 @@ def _get_operator_nodes(nodes):
     return ops


-def _get_relay_input_vars(input_shapes):
-    """ Return Relay vars from input shapes """
-    return {iname: _expr.var(iname, shape=ishape)
-            for iname, ishape in input_shapes.items()}
+def _get_relay_input_vars(graph, input_shapes):
+    """
+    Return Relay vars from input shapes and create entries based on
+    expected graph inputs - to allow translation
+    """
+    input_vars = {}
+    ir_inputs = _get_graph_input_names(graph)
+    for idx, ir_input in enumerate(ir_inputs):

[inline review comment] How about …

+        name, shape = input_shapes[idx]
+        inp = _expr.var(name, shape=shape)
+        # Translate from graph input to user input name
+        input_vars[ir_input] = inp
+
+    return input_vars


 def get_use_chains(root_node, terminate=lambda _: False):
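A minimal sketch of the positional pairing this function performs, assuming hypothetical TorchScript debug names "x.1" and "y.1" (relay.var is the public alias of the frontend's _expr.var):

from tvm import relay

ir_inputs = ["x.1", "y.1"]  # hypothetical graph-internal debug names
input_shapes = [("input0", (1, 3)), ("input1", (3,))]

input_vars = {}
for ir_input, (name, shape) in zip(ir_inputs, input_shapes):
    # Key by the graph-internal name so operator conversion can look the
    # var up; the Relay var itself carries the user-facing name.
    input_vars[ir_input] = relay.var(name, shape=shape)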
@@ -1220,33 +1240,33 @@ def convert_params(graph, state_dict):
     return params, param_tensors, packed_param_map


-def convert_block(block, outputs, output_index_map):
+def convert_block(block, input_vars):
     """ Translate Torch "Block", used for prim::If and prim::Loop """
     ops = _get_operator_nodes(block.nodes())
     ret_names = _get_input_names(block.returnNode())
-    return convert_operators(ops, outputs, output_index_map, ret_names)
+    return convert_operators(ops, input_vars, ret_names)


-def convert_if(if_node, outputs, output_index_map):
+def convert_if(if_node, input_vars):
     """ Translate Torch prim::If to Relay If """
-    cond = outputs[output_index_map[if_node.inputsAt(0).debugName()]]
+    cond = input_vars[if_node.inputsAt(0).debugName()]
     blocks = list(if_node.blocks())
-    true_branch = convert_block(blocks[0], outputs, output_index_map)
-    false_branch = convert_block(blocks[1], outputs, output_index_map)
+    true_branch = convert_block(blocks[0], input_vars)
+    false_branch = convert_block(blocks[1], input_vars)
     assert len(true_branch) == 1 and len(false_branch) == 1
     return _expr.If(cond, true_branch[0], false_branch[0])


-def convert_loop(loop_node, outputs, output_index_map):
+def convert_loop(loop_node, input_vars):
     """ Translate Torch prim::Loop to Relay while_loop """
     def get_input(index):
         ivalue = loop_node.inputsAt(index)
         inode = ivalue.node()
         if inode.kind() == "prim::Constant":
             return _expr.const(_get_constant(inode))
         var_name = ivalue.debugName()
-        assert var_name in output_index_map
-        return _wrap_const(outputs[output_index_map[var_name]])
+        assert var_name in input_vars
+        return _wrap_const(input_vars[var_name])

     # Refer to the spec for prim::Loop below
     # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
@@ -1278,9 +1298,9 @@ def body(*current_vals):
         # Update loop variables using the prev iteration outputs
         assert len(current_vals) == len(block_input_names)
         for (i, iname) in enumerate(block_input_names):
-            outputs[output_index_map[iname]] = current_vals[i]
+            input_vars[iname] = current_vals[i]

-        block_outputs = convert_block(body_block, outputs, output_index_map)
+        block_outputs = convert_block(body_block, input_vars)

         if not is_while_loop:
             # iter var increment implicit in torch, so do it manually
@@ -1310,7 +1330,7 @@ def get_var(name, val):
     name_val_pairs = list(zip(block_input_names,
                               [init_loop_iter_val] + init_vals))
-    _update_outputs_from_pairs(name_val_pairs, outputs, output_index_map)
+    _update_inputs_from_pairs(name_val_pairs, input_vars)

     loop_iter_var = _expr.var(block_input_names[0], shape=(),
                               dtype=loop_iter_dtype)
@@ -1322,36 +1342,32 @@ def get_var(name, val):
     return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]


-def convert_operators(operators, outputs, output_index_map, ret_names):
+def convert_operators(operators, input_vars, ret_names):
     """ Convert each Torch IR operators to Relay equivalent """
     for node_name, op_node in operators:
         operator = op_node.kind()
-        inputs = _get_op_inputs(op_node, outputs, output_index_map)
+        inputs = _get_op_inputs(op_node, input_vars)

         if operator == "prim::Constant":
-            output_index_map[node_name] = len(outputs)
-            outputs.append(_get_constant(op_node))
+            input_vars[node_name] = _get_constant(op_node)
         elif operator == 'prim::ListConstruct' and _is_int_seq(inputs):
-            output_index_map[node_name] = len(outputs)
-            outputs.append(_expr.var(node_name, shape=inputs))
+            input_vars[node_name] = _expr.var(node_name, shape=inputs)
         elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']:
-            output_index_map[node_name] = len(outputs)
-            outputs.append(inputs)
+            input_vars[node_name] = inputs
         elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']:
             assert len(inputs) == 1
             unpacked_names = _get_output_names(op_node)
-            _update_outputs_from_pairs(zip(unpacked_names, inputs[0]),
-                                       outputs, output_index_map)
+            _update_inputs_from_pairs(zip(unpacked_names, inputs[0]),
+                                      input_vars)
         elif operator == "prim::If":
-            if_out = convert_if(op_node, outputs, output_index_map)
-            output_index_map[node_name] = len(outputs)
-            outputs.append(if_out)
+            if_out = convert_if(op_node, input_vars)
+            input_vars[node_name] = if_out
         elif operator == "prim::Loop":
-            loop_out = convert_loop(op_node, outputs, output_index_map)
+            loop_out = convert_loop(op_node, input_vars)
             unpacked_names = _get_output_names(op_node)
             assert len(loop_out) == len(unpacked_names)
-            _update_outputs_from_pairs(zip(unpacked_names, loop_out),
-                                       outputs, output_index_map)
+            _update_inputs_from_pairs(zip(unpacked_names, loop_out),
+                                      input_vars)
         else:
             relay_op = _convert_map[operator]
             relay_out = relay_op(inputs, _get_input_types(op_node))

@@ -1360,13 +1376,12 @@ def convert_operators(operators, outputs, output_index_map, ret_names):
             # This is for torch operators that return multiple outputs
             # See _adaptive_max_2d above for example
             out_names = _get_output_names(op_node)
-            _update_outputs_from_pairs(zip(out_names, relay_out),
-                                       outputs, output_index_map)
+            _update_inputs_from_pairs(zip(out_names, relay_out),
+                                      input_vars)
         else:
-            output_index_map[node_name] = len(outputs)
-            outputs.append(relay_out)
+            input_vars[node_name] = relay_out

-    return [_wrap_const(outputs[output_index_map[ret_name]])
+    return [_wrap_const(input_vars[ret_name])
             for ret_name in ret_names]

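The mechanical pattern across these hunks: the parallel outputs list plus output_index_map indirection collapses into a single dict. A runnable toy comparison:

# Toy stand-ins; only the bookkeeping pattern matters.
outputs, output_index_map = [], {}
node_name, relay_out = "node0", object()

# Before this PR: a list plus a name-to-index map.
output_index_map[node_name] = len(outputs)
outputs.append(relay_out)
assert outputs[output_index_map[node_name]] is relay_out

# After this PR: one dict keyed by the node's debug name.
input_vars = {node_name: relay_out}
assert input_vars[node_name] is relay_out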
@@ -1382,11 +1397,11 @@ def get_all_op_names(graph):
     return set(node.kind() for node in nodes)


-def get_graph_input_names(script_module):
-    """ Use this function to set the keys for input_shapes"""
-    # It seems variable names could change the first time a copy is made
-    # Use the copy of the graph here to prevent troubles later
-    ir_inputs = _get_input_names(script_module.graph.copy())
+def _get_graph_input_names(graph):
+    """ Get the graph input names (use after graph copy and run jit passes) """
+    # Variable names could change the first time a copy is made and after
+    # _run_jit_passes is called, expected that those functions already invoked
+    ir_inputs = _get_input_names(graph)
     return ir_inputs[1:]  # remove self at the 0th arg

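Why the [1:] slice: a scripted nn.Module's graph lists self as its first input. A small check (the exact debug names may vary by PyTorch version; the module is illustrative):

import torch

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1

graph = torch.jit.script(AddOne()).graph
print([inp.debugName() for inp in graph.inputs()])
# Typically prints something like ['self', 'x.1']; the user-facing
# inputs are everything after 'self'.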
@@ -1423,30 +1438,28 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):

     op_names = get_all_op_names(graph)
     _report_missing_conversion(op_names)
-    _check_input_names(script_module, input_shapes)
+    _check_inputs(graph, input_shapes)

     params = script_module.state_dict()
-    input_vars = _get_relay_input_vars(input_shapes)
+    input_vars = _get_relay_input_vars(graph, input_shapes)

[inline review comment] input_vars -> outputs, because inputs are not relay.Var.

     param_vars, tensors, packed_param_map = convert_params(graph, params)
     tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}

     input_vars.update(param_vars)
-    outputs = list(input_vars.values())
-    output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
     ret_name = _get_input_names(graph.return_node())

     # For quantized models
     if "aten::quantize_per_tensor" in op_names:
         weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
         qnn_torch.add_input_quant_params_to_op_inputs(graph)
-        qnn_torch.add_quant_params_to_outputs(outputs, output_index_map,
-                                              packed_param_map,
-                                              weight_quant_params)
+        qnn_torch.add_quant_params_to_inputs(input_vars,
+                                             packed_param_map,
+                                             weight_quant_params)
         qnn_torch.add_quant_params(tvm_params, weight_quant_params)
         _convert_map.update(qnn_torch.convert_map)

-    ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs,
-                            output_index_map, ret_name)
+    ret = convert_operators(_get_operator_nodes(graph.nodes()),
+                            input_vars, ret_name)

     if isinstance(ret[0], list):
         ret[0] = _expr.Tuple(ret[0])

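Putting the API change together, a hedged end-to-end sketch of the new calling convention (the model and input names are illustrative, not part of this PR):

import torch
import torchvision
from tvm import relay

model = torchvision.models.resnet18().eval()
inp = torch.randn(1, 3, 224, 224)
trace = torch.jit.trace(model, inp)

# No get_graph_input_names() call anymore: pass an ordered list of
# (name, shape) tuples, one per model input.
input_shapes = [("input0", (1, 3, 224, 224))]
mod, params = relay.frontend.from_pytorch(trace, input_shapes)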
File: python/tvm/relay/frontend/qnn_torch.py
@@ -101,20 +101,19 @@ def get_weight_quant_params(script_module):
     return quant_params


-def add_quant_params_to_outputs(outputs, output_index_map,
-                                packed_param_map, quant_params):
+def add_quant_params_to_inputs(input_vars, packed_param_map,
+                               quant_params):
     """
-    Add quant params to outputs so that they can be referenced by other
+    Add quant params to inputs so that they can be referenced by other
     ops later. Weights are quantized here.

[inline review comment] For L104 and L107, please keep …

     """
     for node_name, packed_param_name in packed_param_map.items():
         qparam = quant_params[packed_param_name]
-        output_index_map[node_name] = len(outputs)
         qweight = relay.qnn.op.quantize(qparam.weight_var, qparam.scale,
                                         qparam.zero_point, out_dtype="int8",
                                         axis=0)
         param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var)
-        outputs.append(param_tup)
+        input_vars[node_name] = param_tup


 def _get_quant_param_for_input(input_value):

File: tests/python/frontend/pytorch/test_forward.py
@@ -28,7 +28,6 @@
 from tvm import relay
 from tvm.contrib import graph_runtime
 from tvm.relay.testing.config import ctx_list
-from tvm.relay.frontend.pytorch import get_graph_input_names


 sys.setrecursionlimit(10000)

@@ -169,8 +168,8 @@ def verify_model(model_name, input_data=[],
     else:
         trace = trace.cpu()

-    input_names = get_graph_input_names(trace)
-    input_shapes = dict(zip(input_names,
-                            [inp.shape for inp in baseline_input]))
+    input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
+    input_shapes = list(zip(input_names,
+                            [inp.shape for inp in baseline_input]))
     mod, params = relay.frontend.from_pytorch(trace, input_shapes,
                                               custom_convert_map)

@@ -890,11 +889,12 @@ def test_3d_models():

 def verify_script_model(pt_model, ishapes):
     script_module = torch.jit.script(pt_model)
-    input_names = get_graph_input_names(script_module)
-    input_shapes = dict(zip(input_names, ishapes))
-
-    inputs = [torch.randn(input_shapes[input_name], dtype=torch.float)
-              for input_name in input_names]
+    input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
+    input_shapes = list(zip(input_names, ishapes))
+
+    inputs = [torch.randn(shape, dtype=torch.float)
+              for name, shape in input_shapes]

     mod, params = relay.frontend.from_pytorch(script_module, input_shapes)

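For completeness, a hedged usage sketch of the updated test helper (the tiny module is illustrative): names "i0", "i1", ... are generated inside verify_script_model, so a caller only supplies shapes.

import torch

class Flatten(torch.nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)

verify_script_model(Flatten().eval(), ishapes=[(2, 3, 4)])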