
Commit

reorg
masahi committed Feb 26, 2020
1 parent bd41f55 commit 3063bc0
Showing 1 changed file with 79 additions and 81 deletions.
160 changes: 79 additions & 81 deletions python/tvm/relay/frontend/pytorch.py
@@ -734,19 +734,13 @@ def _convert_elemwise_input(data, input_type):
}


def is_int_seq(seq):
return len(seq) > 0 and all([isinstance(i, int) for i in seq])

def run_jit_passes(graph):
if version.parse(torch.__version__) >= version.parse("1.4.0"):
torch._C._jit_pass_inline(graph)

def parse_inputs(graph_inputs, input_shapes):
ir_inputs = list(graph_inputs)
input_vars = {}

for input_name, ir_input in zip(input_shapes, ir_inputs[1:]):
ir_input.setDebugName(input_name)
input_vars[input_name] = _expr.var(input_name,
shape=input_shapes[input_name])
return input_vars
def is_int_seq(seq):
return len(seq) > 0 and all([isinstance(i, int) for i in seq])


def get_tensor_and_var(torch_tensor, name):
@@ -768,13 +762,48 @@ def get_input_names(node):
return [inp.debugName() for inp in node.inputs()]


def get_op_inputs(op_node, outputs, output_index_map):
input_names = [output_index_map[name]
for name in get_input_names(op_node)]
return [outputs[name] for name in input_names]


def update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):
for output_name, output in name_output_pairs:
output_index_map[output_name] = len(outputs)
outputs.append(output)


def get_all_op_names(graph):
nodes = list(graph.nodes())
return set([node.kind() for node in nodes])


def report_missing_conversion(op_names):
known_ops = ["prim::Constant", "prim::GetAttr",
"prim::ListConstruct", "prim::ListUnpack",
"prim::TupleConstruct", "prim::TupleUnpack"]
known_ops += list(_convert_map.keys())

missing = [op_name for op_name in op_names
if op_name not in known_ops]

if missing:
msg = "The following operators are not implemented: {}".format(missing)
raise NotImplementedError(msg)


def getattr_attr_name(node):
attribute_names = node.attributeNames()
assert(len(attribute_names) == 1)
attr_name = node.s(attribute_names[0])
return attr_name


def get_full_attr_name(getattrs):
return ".".join([getattr_attr_name(node) for node in getattrs])


def get_use_chains(root_node, terminate=lambda _: False):
def concat_lists(lists):
return itertools.chain.from_iterable(lists)
@@ -811,36 +840,6 @@ def terminate(users):
return get_use_chains(root_getattr_node, terminate)


def get_full_attr_name(getattrs):
return ".".join([getattr_attr_name(node) for node in getattrs])


def parse_params(graph, state_dict):
getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True)
params = {}
param_tensors = {}
seen = set()

for node in getattr_nodes:
if get_output_name(node) in seen:
continue

for getattrs in get_attr_chains(node):
seen.update(map(get_output_name, getattrs))

full_attr = get_full_attr_name(getattrs)
full_attr_node_name = get_output_name(getattrs[-1])

if full_attr in state_dict:
torch_tensor = state_dict[full_attr]
tensor, var = get_tensor_and_var(torch_tensor,
full_attr_node_name)
param_tensors[full_attr_node_name] = tensor
params[full_attr_node_name] = var

return params, param_tensors


def get_input_types(op_node):
input_list_types = []
for input_node in op_node.inputs():
@@ -896,6 +895,43 @@ get_constant(node):
return None


def parse_inputs(graph_inputs, input_shapes):
ir_inputs = list(graph_inputs)
input_vars = {}

for input_name, ir_input in zip(input_shapes, ir_inputs[1:]):
ir_input.setDebugName(input_name)
input_vars[input_name] = _expr.var(input_name,
shape=input_shapes[input_name])
return input_vars


def parse_params(graph, state_dict):
getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True)
params = {}
param_tensors = {}
seen = set()

for node in getattr_nodes:
if get_output_name(node) in seen:
continue

for getattrs in get_attr_chains(node):
seen.update(map(get_output_name, getattrs))

full_attr = get_full_attr_name(getattrs)
full_attr_node_name = get_output_name(getattrs[-1])

if full_attr in state_dict:
torch_tensor = state_dict[full_attr]
tensor, var = get_tensor_and_var(torch_tensor,
full_attr_node_name)
param_tensors[full_attr_node_name] = tensor
params[full_attr_node_name] = var

return params, param_tensors


def parse_ops(nodes):
ops = {}
# Traverse nodes and add to graph
@@ -911,45 +947,6 @@ def parse_ops(nodes):
return ops


def get_input_node_names(op_node, output_index_map):
return [output_index_map[name] for name in get_input_names(op_node)]


def get_op_inputs(op_node, outputs, output_index_map):
input_names = get_input_node_names(op_node, output_index_map)
return [outputs[name] for name in input_names]


def run_jit_passes(graph):
if version.parse(torch.__version__) >= version.parse("1.4.0"):
torch._C._jit_pass_inline(graph)


def update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):
for output_name, output in name_output_pairs:
output_index_map[output_name] = len(outputs)
outputs.append(output)


def get_all_op_names(graph):
nodes = list(graph.nodes())
return set([node.kind() for node in nodes])


def report_missing_conversion(graph):
known_ops = ["prim::Constant", "prim::GetAttr",
"prim::ListConstruct", "prim::ListUnpack",
"prim::TupleConstruct", "prim::TupleUnpack"]
known_ops += list(_convert_map.keys())

missing = [op_name for op_name in get_all_op_names(graph)
if op_name not in known_ops]

if missing:
msg = "The following operators are not implemented: {}".format(missing)
raise NotImplementedError(msg)


def from_pytorch(script_module, input_shapes):
""" Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
The companion parameters will be handled automatically.
@@ -973,7 +970,8 @@ def from_pytorch(script_module, input_shapes):
"""
graph = script_module.graph.copy()
run_jit_passes(graph)
report_missing_conversion(graph)
op_names = get_all_op_names(graph)
report_missing_conversion(op_names)

params = script_module.state_dict()
input_vars = parse_inputs(graph.inputs(), input_shapes)
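Note (not part of the diff): a minimal usage sketch of the from_pytorch entry point touched above. The model, the input name "input0", and the assumption that from_pytorch returns a Relay module together with a parameter dict are illustrative only, not taken from this commit.

import torch
import torchvision
from tvm import relay

# Trace a model to obtain a scripted PyTorch module.
model = torchvision.models.resnet18(pretrained=True).eval()
script_module = torch.jit.trace(model, torch.randn(1, 3, 224, 224))

# input_shapes maps input names to shapes, matching what parse_inputs expects.
input_shapes = {"input0": (1, 3, 224, 224)}

# Assumed to return the converted Relay module and the companion parameters.
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
print(mod)  # inspect the converted Relay IR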
