Commit

Add transpose on MatMul node rather than on Gather
fxmarty committed Sep 1, 2023
1 parent 155cc8b commit 1a5dbfc
Showing 2 changed files with 58 additions and 17 deletions.
6 changes: 2 additions & 4 deletions optimum/onnx/graph_transformations.py
@@ -22,13 +22,13 @@
from ..utils import logging
from .transformations_utils import (
_create_name_sharing_dict,
_deduplicate_gather_matmul,
_deduplicated_cross_model_initializers,
_find_duplicate_initializers,
_find_matching_initializers,
_get_all_inputs,
_get_onnx_opset,
_get_weights_to_tie,
_remove_duplicate_initializers,
_remove_redundant_initializers,
_replace_input_names,
_unify_onnx_outputs,
@@ -89,9 +89,7 @@ def remove_duplicate_weights_from_tied_info(

tied_groups_map = _find_matching_initializers(tied_params, onnx_model, initializer_name_to_idx)

onnx_model = _remove_duplicate_initializers(
onnx_model, tied_groups_to_tie, tied_groups_map, initializer_name_to_idx
)
onnx_model = _deduplicate_gather_matmul(onnx_model, tied_groups_to_tie, tied_groups_map, initializer_name_to_idx)
check_and_save_model(onnx_model, save_path=save_path)

return onnx_model
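The change above routes weight deduplication through the new _deduplicate_gather_matmul helper. For context, the sketch below (a hypothetical toy model, not Optimum code) shows the torch-side tied-weight information such a pass starts from: two parameter names that resolve to the same tensor, typically the input embedding matrix reused by the LM head.

import torch
from torch import nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size=8, hidden_size=4):
        super().__init__()
        self.shared = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.lm_head.weight = self.shared.weight  # tie the two parameters

model = TinyLM()

# Group parameter names by the storage they point to; tied groups are those with
# more than one name. remove_duplicate=False keeps both aliases visible.
by_storage = {}
for name, param in model.named_parameters(remove_duplicate=False):
    by_storage.setdefault(param.data_ptr(), []).append(name)

tied_params = [names for names in by_storage.values() if len(names) > 1]
print(tied_params)  # e.g. [['shared.weight', 'lm_head.weight']]

When such a model is exported to ONNX, the tied tensor typically ends up duplicated: once as the initializer of a Gather node (the embedding lookup) and once, often transposed, as the initializer of the lm_head MatMul. The transformation updated in the second file below removes that duplication.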
69 changes: 56 additions & 13 deletions optimum/onnx/transformations_utils.py
@@ -332,7 +332,6 @@ def _get_weights_to_tie(tied_params, torch_model: "nn.Module") -> Tuple[List[Lis
module_name = ".".join(param_name.split(".")[:-1])

module = recurse_getattr(torch_model, module_name)

if module.__class__.__name__ not in SUPPORTED_DEDUPLICATION_OPS:
skip_group = True
if len(params) != 2:
@@ -374,32 +373,53 @@ def _find_matching_initializers(
# exactly matching initializer name.
identical_initializer = False
if param_name in initializer_name_to_idx.keys():
torch_to_initializer.append({"param_name": param_name, "initializer_name": {param_name}})
nodes_containing_initializer = set()
for node in model.graph.node:
if param_name in node.input:
nodes_containing_initializer.add(node.name)

torch_to_initializer.append(
{
"param_name": param_name,
"initializer_name": {param_name},
"nodes_containing_initializer": nodes_containing_initializer,
}
)
identical_initializer = True

# If not found (e.g. "lm_head.weight"), we greedily search for all initializers from potentially matching node names (e.g. "lm_head").
# This greedy approach may find more initializers than wanted.
if not identical_initializer:
module_name = ".".join(param_name.split(".")[:-1])
candidate_inputs = []
candidate_inputs = {}
candidate_node_idxs = []
for i, node in enumerate(model.graph.node):
if module_name in node.name:
candidate_node_idxs.append(i)

for node_idx in candidate_node_idxs:
candidate_inputs.extend(model.graph.node[node_idx].input)
node_name = model.graph.node[node_idx].name
candidate_inputs[node_name] = list(model.graph.node[node_idx].input)
torch_to_initializer_param = set()
for input_name in candidate_inputs:
if input_name in initializer_name_to_idx.keys():
torch_to_initializer_param.add(input_name)
nodes_containing_initializer = set()
for node_name, input_names in candidate_inputs.items():
for input_name in input_names:
if input_name in initializer_name_to_idx.keys():
torch_to_initializer_param.add(input_name)
nodes_containing_initializer.add(node_name)

if len(torch_to_initializer_param) == 0:
logger.warning(
f"Could not find ONNX initializer for torch parameter {param_name}. {param_name} will not be checked for deduplication."
)

torch_to_initializer.append({"param_name": param_name, "initializer_name": torch_to_initializer_param})
torch_to_initializer.append(
{
"param_name": param_name,
"initializer_name": torch_to_initializer_param,
"nodes_containing_initializer": nodes_containing_initializer,
}
)

intersect = torch_to_initializer[0]["initializer_name"]
for i in range(1, len(params)):
@@ -420,26 +440,49 @@
return tied_groups_map


def _remove_duplicate_initializers(
def _deduplicate_gather_matmul(
model: ModelProto,
tied_groups_to_tie: List[List[str]],
tied_groups_map: Dict[Tuple[str], List[Dict[str, Any]]],
initializer_name_to_idx: Dict[str, int],
):
"""
Removes the duplicate initializers from the ONNX model based on the information in tied_groups_map i.e. of which ONNX initializers correspond to a single torch parameter.
Removes the duplicate initializers for Gather and MatMul from the ONNX model based on the information in tied_groups_map, i.e. which ONNX initializers correspond to a single torch parameter.
"""
node_name_to_idx = {}
for idx, node in enumerate(model.graph.node):
node_name_to_idx[node.name] = idx

for params in tied_groups_to_tie:
torch_to_initializer = tied_groups_map[tuple(params)]

# We use the first index initializer as the reference for deduplication
ref_initializer_name = next(iter(torch_to_initializer[0]["initializer_name"]))
# ONNX Runtime quantization behaves badly with Transpose -> Gather. Thus, we take the Gather node as the reference and edit the MatMul nodes instead.
ref_idx = None
for i in range(len(torch_to_initializer)):
ops_using_initializer = set()
for node_name in torch_to_initializer[i]["nodes_containing_initializer"]:
ops_using_initializer.add(model.graph.node[node_name_to_idx[node_name]].op_type)

if ops_using_initializer == {"Gather"}:
ref_idx = i
break

if ref_idx is None:
logger.warning(
f"Could not deduplicate initializers corresponding to the torch tied parameters {params} as an initializer used only by Gather nodes could not found. Skipping deduplication."
)
continue

ref_initializer_name = next(iter(torch_to_initializer[ref_idx]["initializer_name"]))
ref_initializer_idx = initializer_name_to_idx[ref_initializer_name]
ref_initializer = model.graph.initializer[ref_initializer_idx]
ref_type = ref_initializer.data_type
ref_data = numpy_helper.to_array(ref_initializer)

for i in range(1, len(torch_to_initializer)):
for i in range(len(torch_to_initializer)):
if i == ref_idx:
continue

initializer_name = next(iter(torch_to_initializer[i]["initializer_name"]))
initializer_idx = initializer_name_to_idx[initializer_name]
initializer = model.graph.initializer[initializer_idx]
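The comment added in _deduplicate_gather_matmul above is the heart of the change: ONNX Runtime quantization handles a Transpose feeding a Gather poorly, so the initializer consumed only by Gather nodes is kept as the shared reference, and the MatMul nodes that held a duplicate copy are rewired through a Transpose of that reference. The sketch below (a toy graph built by hand with made-up node and initializer names, not Optimum's implementation) illustrates that rewiring:

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

vocab_size, hidden_size = 8, 4
weight = np.random.rand(vocab_size, hidden_size).astype(np.float32)

# Two initializers carrying the same tied torch parameter: the embedding matrix
# for the Gather, and its transpose for the lm_head MatMul.
initializers = [
    numpy_helper.from_array(weight, name="model.shared.weight"),
    numpy_helper.from_array(weight.T, name="lm_head.weight"),
]
nodes = [
    helper.make_node("Gather", ["model.shared.weight", "input_ids"], ["embeddings"], name="/model/shared/Gather"),
    helper.make_node("MatMul", ["hidden_states", "lm_head.weight"], ["logits"], name="/lm_head/MatMul"),
]
graph = helper.make_graph(
    nodes,
    "tied_weights_example",
    inputs=[
        helper.make_tensor_value_info("input_ids", TensorProto.INT64, [None]),
        helper.make_tensor_value_info("hidden_states", TensorProto.FLOAT, [None, hidden_size]),
    ],
    outputs=[
        helper.make_tensor_value_info("embeddings", TensorProto.FLOAT, [None, hidden_size]),
        helper.make_tensor_value_info("logits", TensorProto.FLOAT, [None, vocab_size]),
    ],
    initializer=initializers,
)
model = helper.make_model(graph)

# Deduplicate: the Gather's initializer is the reference; drop the MatMul's copy
# and feed the MatMul through a Transpose of the reference, leaving the Gather
# node untouched.
model.graph.initializer.remove(model.graph.initializer[1])
transpose = helper.make_node(
    "Transpose",
    ["model.shared.weight"],
    ["model.shared.weight_transposed"],
    name="/lm_head/Transpose",
    perm=[1, 0],
)
matmul = next(node for node in model.graph.node if node.name == "/lm_head/MatMul")
matmul.input[1] = "model.shared.weight_transposed"
model.graph.node.insert(0, transpose)

onnx.checker.check_model(model)

With the old first-initializer heuristic, the reference could be the MatMul's (transposed) copy, which put a Transpose in front of the Gather; after this commit the Transpose sits only on the MatMul branch, so the embedding lookup keeps reading the plain initializer.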
