diff --git a/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py b/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py index 67fc5b4def..4fa6332179 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py @@ -543,7 +543,7 @@ def _populate_input_output_tensor_maps(map_input_tensor_to_node: Dict[str, List[ map_output_tensor_to_node[output] = node for attribute in node.attribute: - if getattr(attribute, 'g', None) is not None: + if getattr(attribute, 'g').name != '': for subnode in getattr(attribute, 'g').node: OnnxSaver._populate_input_output_tensor_maps(map_input_tensor_to_node, map_output_tensor_to_node, subnode) @@ -800,7 +800,7 @@ def gather_nodes_in_topological_order(): visited.add(id(node)) for attribute in node.attribute: - if getattr(attribute, 'g', None) is not None: + if getattr(attribute, 'g').name != '': # traversing the list in reverse, see 'NOTE1' for subnode in reversed(getattr(attribute, 'g').node): pending_nodes_list.appendleft((subnode, parent_module_name)) @@ -884,7 +884,7 @@ def _populate_graph_and_output_names_lists(onnx_graph: onnx.GraphProto, graphs_l for node in onnx_graph.node: for attribute in node.attribute: - if getattr(attribute, 'g', None) is not None: + if getattr(attribute, 'g').name != '': OnnxSaver._populate_graph_and_output_names_lists(attribute.g, graphs_list, output_names_list) @staticmethod @@ -903,7 +903,7 @@ def _get_onnx_node_map(onnx_graph: onnx.GraphProto, onnx_node_map: Dict[Tuple[st for node in onnx_graph.node: for attribute in node.attribute: - if getattr(attribute, 'g', None) is not None: + if getattr(attribute, 'g').name != '': OnnxSaver._get_onnx_node_map(attribute.g, onnx_node_map) return onnx_node_map @@ -1041,10 +1041,18 @@ def _remove_marked_module_string_from_node_inp_out_names( """ if node_output_name_counter is None: node_output_name_counter = {} - if param_name_to_updated_name is None: param_name_to_updated_name = {} + # Remove 'marked_module' string from input and output field of onnx.GraphProto + for inp in onnx_graph.input: + updated_name = cls._get_updated_name(inp.name) + inp.name = updated_name + for out in onnx_graph.output: + updated_name = cls._get_updated_name(out.name) + out.name = updated_name + + # Remove 'marked_module' string from all node's input and output names. for node in onnx_graph.node: for index, param_name in enumerate(node.output): updated_name = cls._get_updated_name(param_name) @@ -1064,9 +1072,10 @@ def _remove_marked_module_string_from_node_inp_out_names( updated_name = param_name_to_updated_name.get(param_name, updated_name) node.input[index] = updated_name + # Recursively updates subgraph node's input and output names. for node in onnx_graph.node: for attribute in node.attribute: - if getattr(attribute, 'g', None) is not None: + if getattr(attribute, 'g').name != '': cls._remove_marked_module_string_from_node_inp_out_names( attribute.g, node_output_name_counter, param_name_to_updated_name ) @@ -1127,7 +1136,7 @@ def _remove_detached_nodes_from_onnx_graph(cls, onnx_graph: onnx.GraphProto): for node in onnx_graph.node: for attribute in node.attribute: - if getattr(attribute, 'g', None) is not None: + if getattr(attribute, 'g').name != '': OnnxSaver._remove_detached_nodes_from_onnx_graph(attribute.g) @classmethod @@ -1213,7 +1222,7 @@ def _populate_start_and_end_marker_maps(start_marker_map: Dict[str, onnx.NodePro end_marker_map[identifier].append(node) for attribute in node.attribute: - if getattr(attribute, 'g', None) is not None: + if getattr(attribute, 'g').name != '': for subnode in getattr(attribute, 'g').node: OnnxSaver._populate_start_and_end_marker_maps(start_marker_map, end_marker_map, subnode) @@ -1381,7 +1390,7 @@ def _set_output_names(onnx_model: onnx.ModelProto, graphs_list: List[onnx.GraphP for node in onnx_model.graph.node: for attribute in node.attribute: - if getattr(attribute, 'g', None) is not None: + if getattr(attribute, 'g').name != '': OnnxSaver._set_output_names_for_graph(attribute.g, graphs_list, output_names_for_all_graphs, map_output_tensor_to_node, map_input_tensor_to_node) @@ -1452,7 +1461,7 @@ def _get_all_nodes(onnx_graph: onnx.GraphProto, all_nodes: Union[List[onnx.NodeP for node in onnx_graph.node: all_nodes.append(node) for attribute in node.attribute: - if getattr(attribute, 'g', None) is not None: + if getattr(attribute, 'g').name != '': OnnxSaver._get_all_nodes(attribute.g, all_nodes) return all_nodes @@ -1519,7 +1528,7 @@ def _get_all_initializers(onnx_graph: onnx.GraphProto, initializers: Union[List[ initializers.append(initializer) for node in onnx_graph.node: for attribute in node.attribute: - if getattr(attribute, 'g', None) is not None: + if getattr(attribute, 'g').name != '': OnnxSaver._get_all_initializers(attribute.g, initializers) return initializers @@ -1547,7 +1556,7 @@ def _populate_node_to_io_tensor_map_and_valid_param_set( valid_param_set.add(input_tensor) for attribute in node.attribute: - if getattr(attribute, 'g', None) is not None: + if getattr(attribute, 'g').name != '': OnnxSaver._populate_node_to_io_tensor_map_and_valid_param_set(attribute.g, initializer_names, node_to_io_tensor_name_map, valid_param_set) diff --git a/TrainingExtensions/torch/test/python/test_onnx_utils.py b/TrainingExtensions/torch/test/python/test_onnx_utils.py index 2f2e259c8b..ac5f085d62 100644 --- a/TrainingExtensions/torch/test/python/test_onnx_utils.py +++ b/TrainingExtensions/torch/test/python/test_onnx_utils.py @@ -34,6 +34,7 @@ # # @@-COPYRIGHT-END-@@ # ============================================================================= + import contextlib import copy import logging @@ -1089,3 +1090,41 @@ def test_get_unique_node_output_name(self): assert ( param_name_to_updated_name[param_name] == '/down_blocks.0/Add_1/Add_output_0_dup1' ) + + def test_node_names(self): + """ Check if the 'marked_module' string is removed correctly """ + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + return self.conv1(x) + + + pt_model = Model().eval() + dummy_input = torch.randn(1, 3, 24, 24) + with tempfile.TemporaryDirectory() as tmp_dir: + torch.onnx.export(pt_model.eval(), + dummy_input, + os.path.join(tmp_dir, "model.onnx"), + training=torch.onnx.TrainingMode.EVAL, + export_params=True, + input_names=['input'], + output_names=['output']) + model = onnx.load_model(os.path.join(tmp_dir, "model.onnx")) + + # Add 'marked_module' string in input and output field of onnx.GraphProto object. + model.graph.input[0].name = model.graph.input[0].name + '/marked_module' + model.graph.output[0].name = model.graph.input[0].name + '/marked_module' + + # An exception should be raised since the node input/output names are not consistent. + from onnx.checker import ValidationError + with pytest.raises(ValidationError): + onnx.checker.check_model(model) + + # Remove the 'marked_module' string + OnnxSaver._remove_marked_module_string_from_node_inp_out_names(model.graph) + + # model should be consistent. + onnx.checker.check_model(model)