Fix node name within if node #3378

Merged
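In brief: the `OnnxSaver` helpers in `onnx_utils.py` walk node attributes to find nested subgraphs (for example the branches of an `If` node). The old guard, `getattr(attribute, 'g', None) is not None`, is true for every attribute, because protobuf returns a default empty `GraphProto` rather than `None`; the new guard, `getattr(attribute, 'g').name != ''`, only treats an attribute as a subgraph when that `GraphProto` carries a non-empty name. The diff also extends `_remove_marked_module_string_from_node_inp_out_names` so the top-level graph inputs and outputs are renamed along with the node inputs and outputs. Below is a minimal sketch of the guard change; none of this code is part of the PR, and it assumes only the `onnx` package:

```python
from onnx import TensorProto, helper

# A plain tensor attribute: the 'value' attribute of a Constant node.
const_node = helper.make_node(
    'Constant', inputs=[], outputs=['c'],
    value=helper.make_tensor('c_val', TensorProto.FLOAT, [1], [0.0]),
)
attr = const_node.attribute[0]

# Protobuf hands back a default (empty) GraphProto for 'g' on any attribute,
# so the old guard matches even this tensor attribute.
assert getattr(attr, 'g', None) is not None    # old guard: always True
assert getattr(attr, 'g').name == ''           # new guard: correctly skips it

# A real subgraph attribute (e.g. an If branch) carries a named GraphProto,
# so the new guard still recurses into it.
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1])
branch = helper.make_graph([], 'branch_graph', [], [y])
if_node = helper.make_node(
    'If', inputs=['cond'], outputs=['y'],
    then_branch=branch, else_branch=branch,
)
for a in if_node.attribute:
    assert getattr(a, 'g').name != ''          # new guard: True for both branches
```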
33 changes: 21 additions & 12 deletions TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py
@@ -543,7 +543,7 @@ def _populate_input_output_tensor_maps(map_input_tensor_to_node: Dict[str, List[
map_output_tensor_to_node[output] = node

for attribute in node.attribute:
if getattr(attribute, 'g', None) is not None:
if getattr(attribute, 'g').name != '':
for subnode in getattr(attribute, 'g').node:
OnnxSaver._populate_input_output_tensor_maps(map_input_tensor_to_node, map_output_tensor_to_node,
subnode)
@@ -800,7 +800,7 @@ def gather_nodes_in_topological_order():

visited.add(id(node))
for attribute in node.attribute:
if getattr(attribute, 'g', None) is not None:
if getattr(attribute, 'g').name != '':
# traversing the list in reverse, see 'NOTE1'
for subnode in reversed(getattr(attribute, 'g').node):
pending_nodes_list.appendleft((subnode, parent_module_name))
@@ -884,7 +884,7 @@ def _populate_graph_and_output_names_lists(onnx_graph: onnx.GraphProto, graphs_l

for node in onnx_graph.node:
for attribute in node.attribute:
if getattr(attribute, 'g', None) is not None:
if getattr(attribute, 'g').name != '':
OnnxSaver._populate_graph_and_output_names_lists(attribute.g, graphs_list, output_names_list)

@staticmethod
@@ -903,7 +903,7 @@ def _get_onnx_node_map(onnx_graph: onnx.GraphProto, onnx_node_map: Dict[Tuple[st

for node in onnx_graph.node:
for attribute in node.attribute:
if getattr(attribute, 'g', None) is not None:
if getattr(attribute, 'g').name != '':
OnnxSaver._get_onnx_node_map(attribute.g, onnx_node_map)
return onnx_node_map

@@ -1041,10 +1041,18 @@ def _remove_marked_module_string_from_node_inp_out_names(
"""
if node_output_name_counter is None:
node_output_name_counter = {}

if param_name_to_updated_name is None:
param_name_to_updated_name = {}

# Remove 'marked_module' string from input and output field of onnx.GraphProto
for inp in onnx_graph.input:
updated_name = cls._get_updated_name(inp.name)
inp.name = updated_name
for out in onnx_graph.output:
updated_name = cls._get_updated_name(out.name)
out.name = updated_name

# Remove 'marked_module' string from all node's input and output names.
for node in onnx_graph.node:
for index, param_name in enumerate(node.output):
updated_name = cls._get_updated_name(param_name)
@@ -1064,9 +1072,10 @@
updated_name = param_name_to_updated_name.get(param_name, updated_name)
node.input[index] = updated_name

# Recursively updates subgraph node's input and output names.
for node in onnx_graph.node:
for attribute in node.attribute:
if getattr(attribute, 'g', None) is not None:
if getattr(attribute, 'g').name != '':
cls._remove_marked_module_string_from_node_inp_out_names(
attribute.g, node_output_name_counter, param_name_to_updated_name
)
@@ -1127,7 +1136,7 @@ def _remove_detached_nodes_from_onnx_graph(cls, onnx_graph: onnx.GraphProto):

for node in onnx_graph.node:
for attribute in node.attribute:
if getattr(attribute, 'g', None) is not None:
if getattr(attribute, 'g').name != '':
OnnxSaver._remove_detached_nodes_from_onnx_graph(attribute.g)

@classmethod
@@ -1213,7 +1222,7 @@ def _populate_start_and_end_marker_maps(start_marker_map: Dict[str, onnx.NodePro
end_marker_map[identifier].append(node)

for attribute in node.attribute:
if getattr(attribute, 'g', None) is not None:
if getattr(attribute, 'g').name != '':
for subnode in getattr(attribute, 'g').node:
OnnxSaver._populate_start_and_end_marker_maps(start_marker_map, end_marker_map, subnode)

@@ -1381,7 +1390,7 @@ def _set_output_names(onnx_model: onnx.ModelProto, graphs_list: List[onnx.GraphP

for node in onnx_model.graph.node:
for attribute in node.attribute:
if getattr(attribute, 'g', None) is not None:
if getattr(attribute, 'g').name != '':
OnnxSaver._set_output_names_for_graph(attribute.g, graphs_list, output_names_for_all_graphs,
map_output_tensor_to_node, map_input_tensor_to_node)

@@ -1452,7 +1461,7 @@ def _get_all_nodes(onnx_graph: onnx.GraphProto, all_nodes: Union[List[onnx.NodeP
for node in onnx_graph.node:
all_nodes.append(node)
for attribute in node.attribute:
if getattr(attribute, 'g', None) is not None:
if getattr(attribute, 'g').name != '':
OnnxSaver._get_all_nodes(attribute.g, all_nodes)
return all_nodes

@@ -1519,7 +1528,7 @@ def _get_all_initializers(onnx_graph: onnx.GraphProto, initializers: Union[List[
initializers.append(initializer)
for node in onnx_graph.node:
for attribute in node.attribute:
if getattr(attribute, 'g', None) is not None:
if getattr(attribute, 'g').name != '':
OnnxSaver._get_all_initializers(attribute.g, initializers)
return initializers

@@ -1547,7 +1556,7 @@ def _populate_node_to_io_tensor_map_and_valid_param_set(
valid_param_set.add(input_tensor)

for attribute in node.attribute:
if getattr(attribute, 'g', None) is not None:
if getattr(attribute, 'g').name != '':
OnnxSaver._populate_node_to_io_tensor_map_and_valid_param_set(attribute.g, initializer_names,
node_to_io_tensor_name_map,
valid_param_set)
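The other functional change, visible in the `-1041,10 +1041,18` hunk above, is that `_remove_marked_module_string_from_node_inp_out_names` now also renames the top-level `onnx.GraphProto` inputs and outputs, not just the node inputs and outputs. A rough standalone sketch of that step, with `strip_marker` as a hypothetical stand-in for the existing `_get_updated_name` helper:

```python
import onnx

def strip_marker(name: str) -> str:
    # Hypothetical stand-in for OnnxSaver._get_updated_name: drop the
    # '/marked_module' fragments that AIMET's marker modules leave behind.
    return name.replace('/marked_module', '')

def rename_graph_io(graph: onnx.GraphProto) -> None:
    # Mirrors the newly added loops: keep the graph-level input/output names
    # consistent with the node input/output names renamed elsewhere.
    for inp in graph.input:
        inp.name = strip_marker(inp.name)
    for out in graph.output:
        out.name = strip_marker(out.name)
```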
39 changes: 39 additions & 0 deletions TrainingExtensions/torch/test/python/test_onnx_utils.py
@@ -34,6 +34,7 @@
#
# @@-COPYRIGHT-END-@@
# =============================================================================

import contextlib
import copy
import logging
@@ -1089,3 +1090,41 @@ def test_get_unique_node_output_name(self):
assert (
param_name_to_updated_name[param_name] == '/down_blocks.0/Add_1/Add_output_0_dup1'
)

def test_node_names(self):
""" Check if the 'marked_module' string is removed correctly """
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3)

def forward(self, x):
return self.conv1(x)


pt_model = Model().eval()
dummy_input = torch.randn(1, 3, 24, 24)
with tempfile.TemporaryDirectory() as tmp_dir:
torch.onnx.export(pt_model.eval(),
dummy_input,
os.path.join(tmp_dir, "model.onnx"),
training=torch.onnx.TrainingMode.EVAL,
export_params=True,
input_names=['input'],
output_names=['output'])
model = onnx.load_model(os.path.join(tmp_dir, "model.onnx"))

# Add 'marked_module' string in input and output field of onnx.GraphProto object.
model.graph.input[0].name = model.graph.input[0].name + '/marked_module'
model.graph.output[0].name = model.graph.output[0].name + '/marked_module'

# An exception should be raised since the node input/output names are not consistent.
from onnx.checker import ValidationError
with pytest.raises(ValidationError):
onnx.checker.check_model(model)

# Remove the 'marked_module' string
OnnxSaver._remove_marked_module_string_from_node_inp_out_names(model.graph)

# model should be consistent.
onnx.checker.check_model(model)
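The new `test_node_names` case first confirms that `onnx.checker` rejects the model while the graph-level names still carry the marker suffix, then confirms that the cleaned model validates. Assuming the usual repo layout, something like `pytest TrainingExtensions/torch/test/python/test_onnx_utils.py -k test_node_names` should run it in isolation.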