use input names that come with torch IR
masahi committed Feb 26, 2020
1 parent b8d334c commit 10f4f9a
Showing 3 changed files with 30 additions and 21 deletions.
28 changes: 18 additions & 10 deletions python/tvm/relay/frontend/pytorch.py
@@ -759,8 +759,8 @@ def get_output_names(node):
     return [output.debugName() for output in node.outputs()]


-def get_input_names(node):
-    return [inp.debugName() for inp in node.inputs()]
+def get_input_names(node_or_graph):
+    return [inp.debugName() for inp in node_or_graph.inputs()]


 def get_op_inputs(op_node, outputs, output_index_map):
@@ -781,7 +781,7 @@ def get_all_op_names(graph):


 def report_missing_conversion(op_names):
-    """Check if all ops in an input graph are supported by TVM"""
+    """ Check if all ops in an input graph are supported by TVM """
     known_ops = ["prim::Constant", "prim::GetAttr",
                  "prim::ListConstruct", "prim::ListUnpack",
                  "prim::TupleConstruct", "prim::TupleUnpack"]
@@ -828,7 +828,7 @@ def inner(current, accum):


 def get_attr_chains(root_getattr_node):
-    """Returns chains of attribute access starting from root_getattr_node
+    """ Returns chains of attribute access starting from root_getattr_node
     For example, given attribute "block", as in "self.block" when "self" points
     to the top level torch.nn.Module, it returns lists of attribute "chains",
@@ -847,7 +847,7 @@ def terminate(users):


 def get_input_types(op_node):
-    """Returns a torch type for each input nodes"""
+    """ Returns a torch type for each input nodes """
     input_list_types = []
     for input_node in op_node.inputs():
         in_ty = input_node.type()
@@ -875,7 +875,7 @@ def get_input_types(op_node):


 def get_constant(node):
-    """ Retrive a constant associated with this prim::Constant node"""
+    """ Retrive a constant associated with this prim::Constant node """
     attribute_names = node.attributeNames()
     num_attributes = len(attribute_names)

@@ -904,12 +904,11 @@ def get_constant(node):


 def parse_inputs(graph_inputs, input_shapes):
-    """ Return Relay vars from torch input vars"""
+    """ Return Relay vars from torch input vars """
     ir_inputs = list(graph_inputs)
     input_vars = {}

     for input_name, ir_input in zip(input_shapes, ir_inputs[1:]):
-        ir_input.setDebugName(input_name)
         input_vars[input_name] = _expr.var(input_name,
                                            shape=input_shapes[input_name])
     return input_vars
@@ -961,6 +960,14 @@ def parse_ops(nodes):
     return ops


+def get_graph_input_names(script_module):
+    """ Use this function to set the keys for input_shapes"""
+    # It seems variable names could change the first time a copy is made
+    # Use the copy of the graph here to prevent troubles later
+    ir_inputs = get_input_names(script_module.graph.copy())
+    return ir_inputs[1:]  # remove self at the 0th arg
+
+
 def from_pytorch(script_module, input_shapes):
     """ Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
     The companion parameters will be handled automatically.
@@ -971,16 +978,17 @@ def from_pytorch(script_module, input_shapes):
         TorchScripted PyTorch graph
         Note: We currently only support traces (ie: torch.jit.trace(model, input))

-    shape : Dictionary of input dimensions
+    input_shape : Dictionary of input dimensions
         Graph level input shape dictionary
+        The keys should be the same one returned by get_graph_input_names(...) above

     Returns
     -------
     mod : tvm.relay.Module
         The module that optimizations will be performed on.
     params : dict of str to tvm.runtime.NDArray
-        Dict of converted parameters stored in tvm.ndarray format
+        Dict of converted parameters stored in tvm.runtime.ndarray format
     """
     graph = script_module.graph.copy()
     run_jit_passes(graph)
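The net effect of this file's changes: the keys of input_shapes must now be the input names carried by the torch IR, obtained via the new get_graph_input_names helper, instead of arbitrary user-chosen strings. A minimal usage sketch of the new workflow (the torchvision model, input shape, and the 'input.1' name in the comment are illustrative assumptions, not part of this commit):

    import torch
    import torchvision
    from tvm import relay
    from tvm.relay.frontend.pytorch import get_graph_input_names

    # Trace a model; tracing assigns debug names to the graph inputs.
    model = torchvision.models.resnet18().eval()
    inp = torch.randn(1, 3, 224, 224)
    scripted_model = torch.jit.trace(model, inp).eval()

    # Ask the frontend for the IR input names instead of inventing them.
    input_name = get_graph_input_names(scripted_model)[0]  # e.g. 'input.1'
    input_shapes = {input_name: list(inp.shape)}
    mod, params = relay.frontend.from_pytorch(scripted_model, input_shapes)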
9 changes: 5 additions & 4 deletions tests/python/frontend/pytorch/test_forward.py
@@ -30,6 +30,8 @@
 from tvm import relay
 from tvm.contrib import graph_runtime
 from tvm.relay.testing.config import ctx_list
+from tvm.relay.frontend.pytorch import get_graph_input_names
+

 sys.setrecursionlimit(10000)

@@ -167,16 +169,15 @@ def verify_model(model_name, input_data=[]):
         baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
     else:
         baseline_outputs = (baseline_outputs.float().cpu().numpy(),)
-    output_shapes = [out.shape for out in baseline_outputs]
-    dtype = "float32"
-    input_name = "input0"
-    input_shapes = {input_name: list(baseline_input.shape)}
     trace = torch.jit.trace(baseline_model, baseline_input).float().eval()

     if torch.cuda.is_available():
         trace = trace.cuda()
     else:
         trace = trace.cpu()
+
+    input_name = get_graph_input_names(trace)[0]  # only one input
+    input_shapes = {input_name: list(baseline_input.shape)}
     mod, params = relay.frontend.from_pytorch(trace, input_shapes)
     compiled_input = {input_name: tvm.nd.array(baseline_input.cpu().numpy())}

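The test can no longer hardcode "input0": a traced graph carries its own input debug names, and after this commit only those names key into input_shapes. A quick way to inspect the names a trace assigns (the tiny model here is illustrative):

    import torch

    model = torch.nn.ReLU().eval()
    trace = torch.jit.trace(model, torch.randn(2, 2))
    # The 0th graph input is the module itself ('self'); the remaining
    # entries are the real inputs, which is why get_graph_input_names
    # drops the 0th element.
    print([inp.debugName() for inp in trace.graph.inputs()])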
14 changes: 7 additions & 7 deletions tutorials/frontend/from_pytorch.py
@@ -41,14 +41,13 @@
 be unstable.
 """

 # tvm, relay
 import tvm
 from tvm import relay

 # numpy, packaging
 import numpy as np
 from packaging import version

 from tvm.contrib.download import download_testdata
+from tvm.relay.frontend.pytorch import get_graph_input_names

 # PyTorch imports
 import torch

@@ -91,7 +90,8 @@
 # Import the graph to Relay
 # -------------------------
 # Convert PyTorch graph to Relay graph.
-shape_dict = {'img': img.shape}
+input_name = get_graph_input_names(scripted_model)[0]  # only one input
+shape_dict = {input_name: img.shape}
 mod, params = relay.frontend.from_pytorch(scripted_model,
                                           shape_dict)

@@ -116,12 +116,12 @@
 dtype = 'float32'
 m = graph_runtime.create(graph, lib, ctx)
 # Set inputs
-m.set_input('img', tvm.nd.array(img.astype(dtype)))
+m.set_input(input_name, tvm.nd.array(img.astype(dtype)))
 m.set_input(**params)
 # Execute
 m.run()
 # Get outputs
-tvm_output = m.get_output(0, tvm.nd.empty(((1, 1000)), 'float32'))
+tvm_output = m.get_output(0)

 #####################################################################
 # Look up synset name
@@ -163,4 +163,4 @@
 torch_class_key = class_id_to_key[top1_torch]

 print('Relay top-1 id: {}, class name: {}'.format(top1_tvm, key_to_classname[tvm_class_key]))
-print('Torch top-1 id: {}, class name: {}'.format(top1_torch, key_to_classname[torch_class_key]))
\ No newline at end of file
+print('Torch top-1 id: {}, class name: {}'.format(top1_torch, key_to_classname[torch_class_key]))
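Besides switching to the IR input name, the tutorial drops the preallocated output buffer: graph_runtime's get_output can allocate the output itself, since the compiled graph already knows the output shape and dtype. A short sketch of the difference, continuing from the module m built in the tutorial above (the top-1 extraction mirrors the tutorial's own post-processing):

    import numpy as np

    # Old style: the caller had to supply a buffer of the right shape/dtype,
    # e.g. m.get_output(0, tvm.nd.empty((1, 1000), 'float32')).
    # New style: the runtime allocates the buffer from graph metadata.
    tvm_output = m.get_output(0)
    top1_tvm = np.argmax(tvm_output.asnumpy()[0])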
