Commit 3a2cc94

add _ prefix
masahi committed Feb 28, 2020
1 parent 7351c42 commit 3a2cc94
Showing 1 changed file with 83 additions and 82 deletions.
165 changes: 83 additions & 82 deletions python/tvm/relay/frontend/pytorch.py
@@ -733,54 +733,49 @@ def _convert_elemwise_input(data, input_type):
 }


-def run_jit_passes(graph):
+def _run_jit_passes(graph):
     """ The inline pass is necessary to unwrap prim::CallMethod """
     import torch
     if version.parse(torch.__version__) >= version.parse("1.4.0"):
         torch._C._jit_pass_inline(graph)
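
For readers unfamiliar with this pass, a minimal standalone sketch of what it does (the toy function is hypothetical; torch.jit.trace and torch._C._jit_pass_inline are the underlying PyTorch APIs):

    import torch
    from packaging import version

    def add_relu(x):
        return torch.relu(x + 1)

    script_module = torch.jit.trace(add_relu, torch.rand(3))
    graph = script_module.graph.copy()
    # Inlining flattens prim::CallMethod / prim::CallFunction nodes so the
    # converter sees one flat graph; the pass only exists on PyTorch >= 1.4.
    if version.parse(torch.__version__) >= version.parse("1.4.0"):
        torch._C._jit_pass_inline(graph)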


-def is_int_seq(seq):
+def _is_int_seq(seq):
     return len(seq) > 0 and all([isinstance(i, int) for i in seq])


-def get_tensor_and_var(torch_tensor, name):
+def _get_tensor_and_var(torch_tensor, name):
     tensor = tvm.nd.array(torch_tensor.cpu().numpy())
     var = _expr.var(name, shape=tensor.shape)
     return tensor, var


-def get_output_name(node):
+def _get_output_name(node):
     assert node.outputsSize() == 1
     return node.output().debugName()


-def get_output_names(node):
+def _get_output_names(node):
     return [output.debugName() for output in node.outputs()]


-def get_input_names(node_or_graph):
+def _get_input_names(node_or_graph):
     return [inp.debugName() for inp in node_or_graph.inputs()]


-def get_op_inputs(op_node, outputs, output_index_map):
+def _get_op_inputs(op_node, outputs, output_index_map):
     input_names = [output_index_map[name]
-                   for name in get_input_names(op_node)]
+                   for name in _get_input_names(op_node)]
     return [outputs[name] for name in input_names]


-def update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):
+def _update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):
     for output_name, output in name_output_pairs:
         output_index_map[output_name] = len(outputs)
         outputs.append(output)


-def get_all_op_names(graph):
-    nodes = list(graph.nodes())
-    return set(node.kind() for node in nodes)
-
-
-def report_missing_conversion(op_names):
+def _report_missing_conversion(op_names):
     """ Check if all ops in an input graph are supported by TVM """
     known_ops = ["prim::Constant", "prim::GetAttr",
                  "prim::ListConstruct", "prim::ListUnpack",
@@ -795,66 +790,26 @@ def report_missing_conversion(op_names):
         raise NotImplementedError(msg)


-def getattr_attr_name(node):
+def _getattr_attr_name(node):
     attribute_names = node.attributeNames()
     assert len(attribute_names) == 1
     attr_name = node.s(attribute_names[0])
     return attr_name


-def get_full_attr_name(getattrs):
-    return ".".join([getattr_attr_name(node) for node in getattrs])
-
-
-def get_use_chains(root_node, terminate=lambda _: False):
-    """
-    Track a chain of users of this node forward, returning a list of chains
-    See get_attr_chains below for its usage
-    """
-    def concat_lists(lists):
-        return itertools.chain.from_iterable(lists)
-
-    def inner(current, accum):
-        users = []
-        for output in current.outputs():
-            users += [use.user for use in output.uses()]
-
-        if not users or terminate(users):
-            return [accum]
-
-        return concat_lists([inner(nxt, accum + [nxt]) for nxt in users])
-
-    return inner(root_node, [root_node])
-
-
-def get_attr_chains(root_getattr_node):
-    """ Returns chains of attribute access starting from root_getattr_node
-    For example, given attribute "block", as in "self.block" when "self" points
-    to the top level torch.nn.Module, it returns lists of attribute "chains",
-    e.g. ['block', '2'], ['block', '1'], ['block', '0', '_packed_params']
-    These sets of attributes form full attribute accessors. For example,
-    "self.block.1", "self.block.2" will return the second and third submodule,
-    and "self.block.0._packed_params" will return the parameters of the first
-    submodule.
-    """
-    def terminate(users):
-        next_attrs = [user for user in users if user.kind() == "prim::GetAttr"]
-        return len(next_attrs) == 0
-
-    return get_use_chains(root_getattr_node, terminate)
+def _getattr_full_name(getattrs):
+    return ".".join([_getattr_attr_name(node) for node in getattrs])


-def get_input_types(op_node):
+def _get_input_types(op_node):
     """ Returns a torch type for each input nodes """
     input_list_types = []
     for input_node in op_node.inputs():
         in_ty = input_node.type()
         input_node_kind = in_ty.kind()
         if input_node_kind == 'TensorType':
             if in_ty.scalarType() is None:
-                input_list_types.append('float')
+                input_list_types.append(None)
             else:
                 input_list_types.append(in_ty.scalarType().lower())
         elif input_node_kind == 'ListType':
@@ -874,7 +829,7 @@ def get_input_types(op_node):
     return input_list_types


-def get_constant(node):
+def _get_constant(node):
     """ Retrieve a constant associated with this prim::Constant node """
     attribute_names = node.attributeNames()
     num_attributes = len(attribute_names)
@@ -903,15 +858,15 @@ def get_constant(node):
         return None


-def get_operator_nodes(nodes):
+def _get_operator_nodes(nodes):
     """ Returns torch IR nodes that need conversion to Relay """
     ops = {}
     # Traverse nodes and add to graph
     for node in nodes:
         if node.outputsSize() > 1:
-            node_name = "_".join(get_output_names(node))
+            node_name = "_".join(_get_output_names(node))
         else:
-            node_name = get_output_name(node)
+            node_name = _get_output_name(node)
 
         if node.kind() != "prim::GetAttr":
             ops[node_name] = node
@@ -930,6 +885,46 @@ def parse_inputs(graph_inputs, input_shapes):
     return input_vars


+def get_use_chains(root_node, terminate=lambda _: False):
+    """
+    Track a chain of users of this node forward, returning a list of chains
+    See get_attr_chains below for its usage
+    """
+    def concat_lists(lists):
+        return itertools.chain.from_iterable(lists)
+
+    def inner(current, accum):
+        users = []
+        for output in current.outputs():
+            users += [use.user for use in output.uses()]
+
+        if not users or terminate(users):
+            return [accum]
+
+        return concat_lists([inner(nxt, accum + [nxt]) for nxt in users])
+
+    return inner(root_node, [root_node])
+
+
+def get_attr_chains(root_getattr_node):
+    """ Returns chains of attribute access starting from root_getattr_node
+    For example, given attribute "block", as in "self.block" when "self" points
+    to the top level torch.nn.Module, it returns lists of attribute "chains",
+    e.g. ['block', '2'], ['block', '1'], ['block', '0', '_packed_params']
+    These sets of attributes form full attribute accessors. For example,
+    "self.block.1", "self.block.2" will return the second and third submodule,
+    and "self.block.0._packed_params" will return the parameters of the first
+    submodule.
+    """
+    def terminate(users):
+        next_attrs = [user for user in users if user.kind() == "prim::GetAttr"]
+        return len(next_attrs) == 0
+
+    return get_use_chains(root_getattr_node, terminate)
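
As a concrete illustration of the docstring above, a hedged sketch (the module layout is hypothetical and the printed names are illustrative; note that parse_params below additionally deduplicates chains with a seen set):

    import torch
    from tvm.relay.frontend.pytorch import get_attr_chains, _getattr_full_name

    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.block = torch.nn.Sequential(torch.nn.Linear(4, 4),
                                             torch.nn.ReLU())

        def forward(self, x):
            return self.block(x)

    script_module = torch.jit.script(Net())
    graph = script_module.graph.copy()
    torch._C._jit_pass_inline(graph)

    # Each chain of prim::GetAttr nodes joins into a state_dict-style key.
    for node in graph.findAllNodes("prim::GetAttr", recurse=True):
        for chain in get_attr_chains(node):
            print(_getattr_full_name(chain))  # e.g. block.0.weight, block.0.bias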


 def parse_params(graph, state_dict):
     """
     Return Relay vars and TVM NDArrays for input parameters
@@ -941,19 +936,19 @@ def parse_params(graph, state_dict):
     seen = set()

     for node in getattr_nodes:
-        if get_output_name(node) in seen:
+        if _get_output_name(node) in seen:
             continue

         for getattrs in get_attr_chains(node):
-            seen.update(map(get_output_name, getattrs))
+            seen.update(map(_get_output_name, getattrs))

-            full_attr = get_full_attr_name(getattrs)
-            full_attr_node_name = get_output_name(getattrs[-1])
+            full_attr = _getattr_full_name(getattrs)
+            full_attr_node_name = _get_output_name(getattrs[-1])

             if full_attr in state_dict:
                 torch_tensor = state_dict[full_attr]
-                tensor, var = get_tensor_and_var(torch_tensor,
-                                                 full_attr_node_name)
+                tensor, var = _get_tensor_and_var(torch_tensor,
+                                                  full_attr_node_name)
                 param_tensors[full_attr_node_name] = tensor
                 params[full_attr_node_name] = var
@@ -964,35 +959,41 @@ def parse_operators(operators, outputs, output_index_map, ret_name):
     """ Convert each Torch IR operators to Relay equivalent """
     for node_name, op_node in operators.items():
         operator = op_node.kind()
-        inputs = get_op_inputs(op_node, outputs, output_index_map)
+        inputs = _get_op_inputs(op_node, outputs, output_index_map)

if operator == "prim::Constant":
output_index_map[node_name] = len(outputs)
outputs.append(get_constant(op_node))
elif operator == 'prim::ListConstruct' and is_int_seq(inputs):
outputs.append(_get_constant(op_node))
elif operator == 'prim::ListConstruct' and _is_int_seq(inputs):
output_index_map[node_name] = len(outputs)
outputs.append(_expr.var(node_name, shape=inputs))
elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']:
output_index_map[node_name] = len(outputs)
outputs.append(inputs)
elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']:
assert len(inputs) == 1
unpacked_names = get_output_names(op_node)
update_outputs_from_pairs(zip(unpacked_names, inputs[0]),
outputs, output_index_map)
unpacked_names = _get_output_names(op_node)
_update_outputs_from_pairs(zip(unpacked_names, inputs[0]),
outputs, output_index_map)
else:
output_index_map[node_name] = len(outputs)
relay_op = _convert_map[operator]
outputs.append(relay_op(inputs, get_input_types(op_node)))
outputs.append(relay_op(inputs, _get_input_types(op_node)))

     return outputs[output_index_map[ret_name]]
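
The outputs list plus output_index_map used throughout parse_operators acts as a simple append-only symbol table: every produced Relay value is appended to outputs, and output_index_map records the slot where each Torch IR name landed. A tiny self-contained sketch (placeholder strings stand in for Relay expressions):

    from tvm.relay.frontend.pytorch import _update_outputs_from_pairs

    outputs = []
    output_index_map = {}

    # Mimic an op producing two outputs, as in the prim::ListUnpack branch.
    _update_outputs_from_pairs([("x1", "relay_expr_a"), ("x2", "relay_expr_b")],
                               outputs, output_index_map)
    assert outputs[output_index_map["x2"]] == "relay_expr_b"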


+def get_all_op_names(graph):
+    """ Return all operator names in the input graph """
+    nodes = list(graph.nodes())
+    return set(node.kind() for node in nodes)


 def get_graph_input_names(script_module):
     """ Use this function to set the keys for input_shapes"""
     # It seems variable names could change the first time a copy is made
     # Use the copy of the graph here to prevent troubles later
-    ir_inputs = get_input_names(script_module.graph.copy())
+    ir_inputs = _get_input_names(script_module.graph.copy())
     return ir_inputs[1:] # remove self at the 0th arg
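
A short usage sketch (the shape is hypothetical): the names returned here are the TorchScript graph inputs minus self, and they are exactly the keys from_pytorch expects in input_shapes.

    input_names = get_graph_input_names(script_module)  # e.g. ['input.1']
    input_shapes = {name: (1, 3, 224, 224) for name in input_names}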


@@ -1019,9 +1020,9 @@ def from_pytorch(script_module, input_shapes):
         Dict of converted parameters stored in tvm.runtime.ndarray format
     """
     graph = script_module.graph.copy()
-    run_jit_passes(graph)
+    _run_jit_passes(graph)
     op_names = get_all_op_names(graph)
-    report_missing_conversion(op_names)
+    _report_missing_conversion(op_names)

     params = script_module.state_dict()
     input_vars = parse_inputs(graph.inputs(), input_shapes)
@@ -1030,9 +1031,9 @@ def from_pytorch(script_module, input_shapes):
     input_vars.update(param_vars)
     outputs = list(input_vars.values())
     output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
-    ret_name = get_input_names(graph.return_node())[0]
+    ret_name = _get_input_names(graph.return_node())[0]

-    body = parse_operators(get_operator_nodes(graph.nodes()), outputs,
+    body = parse_operators(_get_operator_nodes(graph.nodes()), outputs,
                            output_index_map, ret_name)
     func = tvm.relay.Function(_analysis.free_vars(body), body)
     tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
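Putting the pieces together, a minimal end-to-end sketch of the converter as of this commit (the torchvision model is illustrative, and the Relay build step is an assumed standard flow whose exact API varies across TVM versions):

    import torch
    import torchvision
    from tvm import relay
    from tvm.relay.frontend.pytorch import from_pytorch, get_graph_input_names

    model = torchvision.models.resnet18(pretrained=True).eval()
    inp = torch.rand(1, 3, 224, 224)
    script_module = torch.jit.trace(model, inp)

    input_name = get_graph_input_names(script_module)[0]
    func, params = from_pytorch(script_module, {input_name: (1, 3, 224, 224)})

    # Assumed standard Relay build flow: compile the converted function.
    with relay.build_config(opt_level=3):
        graph_json, lib, params = relay.build(func, target="llvm", params=params)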
