-
Notifications
You must be signed in to change notification settings - Fork 455
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix PyTorch tied weights being duplicated in the exported ONNX models #1326
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,25 +14,32 @@ | |
import copy | ||
import os | ||
from pathlib import Path | ||
from typing import Optional, Union | ||
from typing import TYPE_CHECKING, List, Optional, Union | ||
|
||
import onnx | ||
from onnx import ModelProto | ||
|
||
from ..utils import logging | ||
from .transformations_utils import ( | ||
_create_name_sharing_dict, | ||
_deduplicate_gather_matmul, | ||
_deduplicated_cross_model_initializers, | ||
_find_duplicate_initializers, | ||
_find_matching_initializers, | ||
_get_all_inputs, | ||
_get_onnx_opset, | ||
_get_weights_to_tie, | ||
_remove_redundant_initializers, | ||
_replace_input_names, | ||
_unify_onnx_outputs, | ||
cast_int64_tensorproto_to_int32, | ||
) | ||
|
||
|
||
if TYPE_CHECKING: | ||
import torch.nn as nn | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Reviewer] Should this also be under a dependency-availability check (e.g. `is_torch_available()`)? [question truncated in page capture — referent inferred] [Author] We can probably assume that a dev has all the dependencies installed so probably not needed. |
||
|
||
|
||
logger = logging.get_logger() | ||
|
||
|
||
|
@@ -41,6 +48,8 @@ def remove_duplicate_weights(model: ModelProto, inplace: bool = False) -> ModelP | |
Finds and removes duplicate weights in a model by keeping only unique weights, and make the duplicate values point | ||
to them. | ||
|
||
This function only removes duplicate weights that are exactly identical (e.g., not transposed). | ||
|
||
Args: | ||
model (`onnx.ModelProto`): The model to remove duplicates from. | ||
inplace (`bool`, defaults to False): Whether to perform this transformation inplace. | ||
|
@@ -59,6 +68,35 @@ def remove_duplicate_weights(model: ModelProto, inplace: bool = False) -> ModelP | |
return model | ||
|
||
|
||
def remove_duplicate_weights_from_tied_info( | ||
onnx_model: ModelProto, torch_model: "nn.Module", tied_params: List[List[str]], save_path: str | ||
): | ||
""" | ||
Tries to remove potential duplicate ONNX initializers from the tied information in tied_params. | ||
|
||
Args: | ||
onnx_model (`onnx.ModelProto`): | ||
The ONNX model for which to tie potentially duplicate initializers. | ||
fxmarty marked this conversation as resolved.
Show resolved
Hide resolved
|
||
torch_model (`nn.Module`): | ||
The PyTorch model corresponding to the ONNX one. | ||
tied_params (`List[List[str]]`): | ||
A list of groups of torch parameters that are tied, i.e. shared. For them, | ||
the torch module shares the same pointer. | ||
""" | ||
tied_groups_to_tie, tied_groups_ignored = _get_weights_to_tie(tied_params, torch_model) | ||
|
||
initializer_name_to_idx = {} | ||
for idx, initializer in enumerate(onnx_model.graph.initializer): | ||
initializer_name_to_idx[initializer.name] = idx | ||
|
||
tied_groups_map = _find_matching_initializers(tied_params, onnx_model, initializer_name_to_idx) | ||
|
||
onnx_model = _deduplicate_gather_matmul(onnx_model, tied_groups_to_tie, tied_groups_map, initializer_name_to_idx) | ||
check_and_save_model(onnx_model, save_path=save_path) | ||
|
||
return onnx_model | ||
|
||
|
||
def replace_atenops_to_gather(model: ModelProto) -> ModelProto: | ||
""" | ||
Replaces broken ATenOp nodes back to Gather nodes. | ||
|
@@ -89,6 +127,54 @@ def replace_atenops_to_gather(model: ModelProto) -> ModelProto: | |
return model | ||
|
||
|
||
def check_and_save_model(model: onnx.ModelProto, save_path: Optional[Union[str, Path]]): | ||
# for large models, a path must be provided instead of a ModelProto: | ||
# https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md#checking-a-large-onnx-model-2gb | ||
if model.ByteSize() < onnx.checker.MAXIMUM_PROTOBUF: | ||
# For the try catch, refer to https://github.com/microsoft/onnxruntime/issues/14768 | ||
try: | ||
onnx.checker.check_model(model) | ||
except Exception as e: | ||
if "No Op registered for" in str(e): | ||
pass | ||
else: | ||
raise e | ||
if save_path: | ||
# Overwrite. | ||
save_path = str(save_path) | ||
if save_path.endswith(".onnx") and os.path.isfile(save_path): | ||
os.remove(save_path) | ||
save_path = Path(save_path).as_posix() | ||
onnx.save(model, save_path) | ||
Comment on lines
+142
to
+148
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So no saving is done if the model is small enough and we dont provide a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes |
||
elif save_path is not None: | ||
# path/to/model.onnx | ||
save_path = Path(save_path).as_posix() | ||
|
||
external_file_name = os.path.basename(save_path) + "_data" | ||
# path/to/model.onnx_data | ||
external_path = os.path.join(os.path.dirname(save_path), external_file_name) | ||
if save_path.endswith(".onnx") and os.path.isfile(save_path): | ||
os.remove(save_path) | ||
if os.path.isfile(external_path): | ||
os.remove(external_path) | ||
onnx.save( | ||
model, | ||
save_path, | ||
save_as_external_data=True, | ||
all_tensors_to_one_file=True, | ||
location=external_file_name, | ||
) | ||
try: | ||
onnx.checker.check_model(save_path) | ||
except Exception as e: | ||
if "No Op registered for" in str(e): | ||
pass | ||
else: | ||
raise e | ||
Comment on lines
+149
to
+173
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this case we save the model first, meaning that it is possible that an ONNX file is saved with the check failing afterwards? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes but that is intended. onnx checker can either take a ModelProto or a path, and in case of a model larger than 2 GB, a path needs to be used. |
||
else: | ||
logger.info("Merged ONNX model exceeds 2GB, the model will not be checked without `save_path` given.") | ||
|
||
|
||
def merge_decoders( | ||
decoder: Union[ModelProto, Path, str], | ||
decoder_with_past: Union[ModelProto, Path, str], | ||
|
@@ -209,38 +295,7 @@ def merge_decoders( | |
|
||
merged_model = onnx.helper.make_model(merged_graph, producer_name=producer_name, opset_imports=opset_imports) | ||
|
||
# for large models, a path must be provided instead of a ModelProto: | ||
# https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md#checking-a-large-onnx-model-2gb | ||
if merged_model.ByteSize() < onnx.checker.MAXIMUM_PROTOBUF: | ||
# For the try catch, refer to https://github.com/microsoft/onnxruntime/issues/14768 | ||
try: | ||
onnx.checker.check_model(merged_model) | ||
except Exception as e: | ||
if "No Op registered for" in str(e): | ||
pass | ||
else: | ||
raise e | ||
if save_path: | ||
save_path = Path(save_path).as_posix() | ||
onnx.save(merged_model, save_path) | ||
elif save_path is not None: | ||
save_path = Path(save_path).as_posix() | ||
onnx.save( | ||
merged_model, | ||
save_path, | ||
save_as_external_data=True, | ||
all_tensors_to_one_file=True, | ||
location=os.path.basename(save_path) + "_data", | ||
) | ||
try: | ||
onnx.checker.check_model(save_path) | ||
except Exception as e: | ||
if "No Op registered for" in str(e): | ||
pass | ||
else: | ||
raise e | ||
else: | ||
logger.info("Merged ONNX model exceeds 2GB, the model will not be checked without `save_path` given.") | ||
check_and_save_model(merged_model, save_path=save_path) | ||
|
||
return merged_model | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
[Reviewer] nit: import `torch.nn` here?
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
[Author] I put it rather at the top of the file, under `if is_torch_available()`.