From 10f4f9a07d6e9da717dc1accbbeda3cb2f174acb Mon Sep 17 00:00:00 2001
From: masahi
Date: Wed, 26 Feb 2020 17:04:10 +0900
Subject: [PATCH] use input names that come with torch IR

---
 python/tvm/relay/frontend/pytorch.py          | 28 ++++++++++++-------
 tests/python/frontend/pytorch/test_forward.py |  9 +++---
 tutorials/frontend/from_pytorch.py            | 14 +++++-----
 3 files changed, 30 insertions(+), 21 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 2526d71f6dc9f..3f8d31d54dac8 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -759,8 +759,8 @@ def get_output_names(node):
     return [output.debugName() for output in node.outputs()]
 
 
-def get_input_names(node):
-    return [inp.debugName() for inp in node.inputs()]
+def get_input_names(node_or_graph):
+    return [inp.debugName() for inp in node_or_graph.inputs()]
 
 
 def get_op_inputs(op_node, outputs, output_index_map):
@@ -781,7 +781,7 @@ def get_all_op_names(graph):
 
 
 def report_missing_conversion(op_names):
-    """Check if all ops in an input graph are supported by TVM"""
+    """ Check if all ops in an input graph are supported by TVM """
     known_ops = ["prim::Constant", "prim::GetAttr",
                  "prim::ListConstruct", "prim::ListUnpack",
                  "prim::TupleConstruct", "prim::TupleUnpack"]
@@ -828,7 +828,7 @@ def inner(current, accum):
 
 
 def get_attr_chains(root_getattr_node):
-    """Returns chains of attribute access starting from root_getattr_node
+    """ Returns chains of attribute access starting from root_getattr_node
 
     For example, given attribute "block", as in "self.block" when "self" points
     to the top level torch.nn.Module, it returns lists of attribute "chains",
@@ -847,7 +847,7 @@ def terminate(users):
 
 
 def get_input_types(op_node):
-    """Returns a torch type for each input nodes"""
+    """ Returns a torch type for each input node """
     input_list_types = []
     for input_node in op_node.inputs():
         in_ty = input_node.type()
@@ -875,7 +875,7 @@ def get_input_types(op_node):
 
 
 def get_constant(node):
-    """ Retrive a constant associated with this prim::Constant node"""
+    """ Retrieve a constant associated with this prim::Constant node """
     attribute_names = node.attributeNames()
     num_attributes = len(attribute_names)
 
@@ -904,12 +904,11 @@ def get_constant(node):
 
 
 def parse_inputs(graph_inputs, input_shapes):
-    """ Return Relay vars from torch input vars"""
+    """ Return Relay vars from torch input vars """
     ir_inputs = list(graph_inputs)
     input_vars = {}
 
     for input_name, ir_input in zip(input_shapes, ir_inputs[1:]):
-        ir_input.setDebugName(input_name)
         input_vars[input_name] = _expr.var(input_name,
                                            shape=input_shapes[input_name])
     return input_vars
@@ -961,6 +960,14 @@ def parse_ops(nodes):
     return ops
 
 
+def get_graph_input_names(script_module):
+    """ Use this function to set the keys for input_shapes """
+    # It seems variable names could change the first time a copy is made
+    # Use the copy of the graph here to prevent troubles later
+    ir_inputs = get_input_names(script_module.graph.copy())
+    return ir_inputs[1:]  # remove self at the 0th arg
+
+
 def from_pytorch(script_module, input_shapes):
     """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
     The companion parameters will be handled automatically.
@@ -971,8 +978,9 @@ def from_pytorch(script_module, input_shapes):
         TorchScripted PyTorch graph
         Note: We currently only support traces (ie: torch.jit.trace(model, input))
 
-    shape : Dictionary of input dimensions
+    input_shapes : Dictionary of input dimensions
         Graph level input shape dictionary
+        The keys should be the same ones returned by get_graph_input_names(...) above
 
     Returns
     -------
@@ -980,7 +988,7 @@ def from_pytorch(script_module, input_shapes):
         The module that optimizations will be performed on.
 
     params : dict of str to tvm.runtime.NDArray
-        Dict of converted parameters stored in tvm.ndarray format
+        Dict of converted parameters stored in tvm.runtime.ndarray format
     """
     graph = script_module.graph.copy()
     run_jit_passes(graph)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 54b89eceef613..a9594ac789ba0 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -30,6 +30,8 @@
 from tvm import relay
 from tvm.contrib import graph_runtime
 from tvm.relay.testing.config import ctx_list
+from tvm.relay.frontend.pytorch import get_graph_input_names
+
 
 sys.setrecursionlimit(10000)
 
@@ -167,16 +169,15 @@ def verify_model(model_name, input_data=[]):
         baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
     else:
         baseline_outputs = (baseline_outputs.float().cpu().numpy(),)
-    output_shapes = [out.shape for out in baseline_outputs]
-    dtype = "float32"
-    input_name = "input0"
-    input_shapes = {input_name: list(baseline_input.shape)}
 
     trace = torch.jit.trace(baseline_model, baseline_input).float().eval()
+
     if torch.cuda.is_available():
         trace = trace.cuda()
     else:
         trace = trace.cpu()
+    input_name = get_graph_input_names(trace)[0]  # only one input
+    input_shapes = {input_name: list(baseline_input.shape)}
 
     mod, params = relay.frontend.from_pytorch(trace, input_shapes)
     compiled_input = {input_name: tvm.nd.array(baseline_input.cpu().numpy())}
diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py
index c280c259c1fe4..503f64a4e7d90 100644
--- a/tutorials/frontend/from_pytorch.py
+++ b/tutorials/frontend/from_pytorch.py
@@ -41,14 +41,13 @@
 be unstable.
 """
 
-# tvm, relay
 import tvm
 from tvm import relay
 
-# numpy, packaging
 import numpy as np
-from packaging import version
+
 from tvm.contrib.download import download_testdata
+from tvm.relay.frontend.pytorch import get_graph_input_names
 
 # PyTorch imports
 import torch
@@ -91,7 +90,8 @@
 # Import the graph to Relay
 # -------------------------
 # Convert PyTorch graph to Relay graph.
-shape_dict = {'img': img.shape}
+input_name = get_graph_input_names(scripted_model)[0]  # only one input
+shape_dict = {input_name: img.shape}
 mod, params = relay.frontend.from_pytorch(scripted_model, shape_dict)
 
 
@@ -116,12 +116,12 @@
 dtype = 'float32'
 m = graph_runtime.create(graph, lib, ctx)
 # Set inputs
-m.set_input('img', tvm.nd.array(img.astype(dtype)))
+m.set_input(input_name, tvm.nd.array(img.astype(dtype)))
 m.set_input(**params)
 # Execute
 m.run()
 # Get outputs
-tvm_output = m.get_output(0, tvm.nd.empty(((1, 1000)), 'float32'))
+tvm_output = m.get_output(0)
 
 #####################################################################
 # Look up synset name
@@ -163,4 +163,4 @@
 torch_class_key = class_id_to_key[top1_torch]
 
 print('Relay top-1 id: {}, class name: {}'.format(top1_tvm, key_to_classname[tvm_class_key]))
-print('Torch top-1 id: {}, class name: {}'.format(top1_torch, key_to_classname[torch_class_key]))
\ No newline at end of file
+print('Torch top-1 id: {}, class name: {}'.format(top1_torch, key_to_classname[torch_class_key]))
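
Usage sketch (not part of the patch): a minimal example of the API introduced above,
assuming torchvision is available and using resnet18 purely as a stand-in model.

    import torch
    import torchvision

    from tvm import relay
    from tvm.relay.frontend.pytorch import get_graph_input_names

    # Trace an eager model; the frontend currently expects a traced module
    model = torchvision.models.resnet18(pretrained=True).eval()
    dummy_input = torch.rand(1, 3, 224, 224)
    trace = torch.jit.trace(model, dummy_input).eval()

    # Input names now come from the Torch IR instead of a hard-coded "input0"
    input_name = get_graph_input_names(trace)[0]  # this model has a single input
    input_shapes = {input_name: list(dummy_input.shape)}

    mod, params = relay.frontend.from_pytorch(trace, input_shapes)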