Fix PyTorch tied weights being duplicated in the exported ONNX models #1326

Merged · 5 commits · Sep 1, 2023
Changes from 4 commits
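For context, a minimal sketch of the problem this PR fixes (the toy model and names below are hypothetical, not from the PR): in PyTorch, tied weights are two parameter names backed by a single tensor, but a naive ONNX export can serialize that tensor once per name, duplicating it in the exported file.

# Sketch of weight tying (hypothetical toy model).
# accelerate's find_tied_parameters is the helper this PR uses to detect ties.
import torch
import torch.nn as nn
from accelerate.utils import find_tied_parameters

class TinyTiedLM(nn.Module):
    def __init__(self, vocab_size: int = 100, hidden: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        # Tie the LM head to the embedding: one tensor, two parameter names.
        self.lm_head.weight = self.embed.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.embed(input_ids))

model = TinyTiedLM()
# Both parameter names reference the same storage in PyTorch...
assert model.embed.weight.data_ptr() == model.lm_head.weight.data_ptr()
# ...and find_tied_parameters reports the tied group, e.g.
# [['embed.weight', 'lm_head.weight']]
print(find_tied_parameters(model))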
4 changes: 1 addition & 3 deletions optimum/exporters/onnx/__main__.py
@@ -445,7 +445,7 @@ def main_export(
f" referring to `optimum.exporters.tasks.TaskManager`'s `_TASKS_TO_AUTOMODELS`."
)

onnx_files_subpaths = None
onnx_files_subpaths = [key + ".onnx" for key in models_and_onnx_configs.keys()]
else:
# save the subcomponent configuration
for model_name in models_and_onnx_configs:
@@ -488,8 +488,6 @@ def main_export(
if optimize is not None:
from ...onnxruntime import AutoOptimizationConfig, ORTOptimizer

if onnx_files_subpaths is None:
onnx_files_subpaths = [key + ".onnx" for key in models_and_onnx_configs.keys()]
optimizer = ORTOptimizer.from_pretrained(output, file_names=onnx_files_subpaths)

optimization_config = AutoOptimizationConfig.with_optimization_level(optimization_level=optimize)
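For illustration, the change above always derives one ONNX file name per submodel key instead of leaving onnx_files_subpaths as None and reconstructing the list later for the optimizer. A sketch (key names assumed to match a typical seq2seq export):

# Sketch: subpaths derived from the submodel keys (keys assumed; values elided).
models_and_onnx_configs = {
    "encoder_model": ...,             # (model, onnx_config) pairs in the real code
    "decoder_model": ...,
    "decoder_with_past_model": ...,
}
onnx_files_subpaths = [key + ".onnx" for key in models_and_onnx_configs.keys()]
print(onnx_files_subpaths)
# ['encoder_model.onnx', 'decoder_model.onnx', 'decoder_with_past_model.onnx']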
54 changes: 40 additions & 14 deletions optimum/exporters/onnx/base.py
@@ -19,14 +19,22 @@
import gc
import inspect
import itertools
import os
import re
from abc import ABC, abstractmethod
from collections import OrderedDict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
from transformers.utils import is_torch_available
import onnx
from transformers.utils import is_accelerate_available, is_torch_available

from ...onnx import remove_duplicate_weights_from_tied_info


if is_torch_available():
import torch.nn as nn

from ...onnx import merge_decoders
from ...utils import (
@@ -42,10 +50,13 @@
from ...utils.doc import add_dynamic_docstring
from ...utils.import_utils import check_if_transformers_greater, is_onnx_available, is_onnxruntime_available
from ..base import ExportConfig
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
from .model_patcher import ModelPatcher, Seq2SeqModelPatcher


if is_accelerate_available():
from accelerate.utils import find_tied_parameters

if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel

@@ -505,9 +516,27 @@ def post_process_exported_models(
models_and_onnx_configs (`Dict[str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]]`):
A dictionary containing the models to apply post-processing on, and their corresponding ONNX configurations.
onnx_files_subpaths (`List[str]`):
The relative paths from the export directory to the ONNX files to do post-processing on. The order must be the same as*
the order of submodels in the ordered dict `models_and_onnx_configs`.
The relative paths from the export directory to the ONNX files to do post-processing on. The order must be the same as
the order of submodels in the ordered dict `models_and_onnx_configs`.
"""
first_key = next(iter(models_and_onnx_configs))
if is_torch_available() and isinstance(models_and_onnx_configs[first_key][0], nn.Module):
if is_accelerate_available():
Comment on lines +523 to +524
Member: nit: import torch.nn here?
Contributor Author (fxmarty): I put it rather at the top of the file, under if is_torch_available().
logger.info("Deduplicating shared (tied) weights...")
keys = list(models_and_onnx_configs.keys())
for i, subpath in enumerate(onnx_files_subpaths):
onnx_model = onnx.load(os.path.join(path, subpath))

torch_model = models_and_onnx_configs[keys[i]][0]
tied_params = find_tied_parameters(torch_model)
remove_duplicate_weights_from_tied_info(
onnx_model, torch_model, tied_params, save_path=os.path.join(path, subpath)
)
else:
logger.warning(
"Weight deduplication check in the ONNX export requires accelerate. Please install accelerate to run it."
)

return models_and_onnx_configs, onnx_files_subpaths


@@ -918,14 +947,14 @@ def post_process_exported_models(
],
onnx_files_subpaths: List[str],
):
models_and_onnx_configs, onnx_files_subpaths = super().post_process_exported_models(
path, models_and_onnx_configs, onnx_files_subpaths
)

# Attempt to merge only if the decoder was exported without/with past
if self.use_past is True and len(models_and_onnx_configs) == 3:
if onnx_files_subpaths is not None:
decoder_path = Path(path, onnx_files_subpaths[1])
decoder_with_past_path = Path(path, onnx_files_subpaths[2])
else:
decoder_path = Path(path, ONNX_DECODER_NAME + ".onnx")
decoder_with_past_path = Path(path, ONNX_DECODER_WITH_PAST_NAME + ".onnx")
decoder_path = Path(path, onnx_files_subpaths[1])
decoder_with_past_path = Path(path, onnx_files_subpaths[2])
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
try:
# The decoder with past does not output the cross attention past key values as they are constant,
@@ -940,10 +969,7 @@
raise Exception(f"Unable to merge decoders. Detailed error: {e}")

# In order to do the validation of the two branches on the same file
if onnx_files_subpaths is not None:
encoder_path = onnx_files_subpaths[0]
else:
encoder_path = ONNX_ENCODER_NAME + ".onnx"
encoder_path = onnx_files_subpaths[0]

onnx_files_subpaths = [encoder_path, decoder_merged_path.name, decoder_merged_path.name]
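For reference, a sketch of the resulting subpaths (file names assumed from optimum's ONNX_*_NAME constants): both decoder slots point at the single merged file, so the two branches are validated against the same model.

# Sketch: subpaths after merging (names assumed).
onnx_files_subpaths = [
    "encoder_model.onnx",
    "decoder_model_merged.onnx",  # without-past branch
    "decoder_model_merged.onnx",  # with-past branch
]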

12 changes: 6 additions & 6 deletions optimum/exporters/onnx/config.py
@@ -98,14 +98,14 @@ def post_process_exported_models(
],
onnx_files_subpaths: List[str],
):
models_and_onnx_configs, onnx_files_subpaths = super().post_process_exported_models(
path, models_and_onnx_configs, onnx_files_subpaths
)

# Attempt to merge only if the decoder-only was exported separately without/with past
if self.use_past is True and len(models_and_onnx_configs) == 2:
if onnx_files_subpaths is not None:
decoder_path = Path(path, onnx_files_subpaths[0])
decoder_with_past_path = Path(path, onnx_files_subpaths[1])
else:
decoder_path = Path(path, ONNX_DECODER_NAME + ".onnx")
decoder_with_past_path = Path(path, ONNX_DECODER_WITH_PAST_NAME + ".onnx")
decoder_path = Path(path, onnx_files_subpaths[0])
decoder_with_past_path = Path(path, onnx_files_subpaths[1])
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
try:
merge_decoders(
2 changes: 2 additions & 0 deletions optimum/onnx/__init__.py
@@ -22,6 +22,7 @@
"merge_decoders",
"remove_duplicate_weights",
"replace_atenops_to_gather",
"remove_duplicate_weights_from_tied_info",
],
}

@@ -30,6 +31,7 @@
cast_slice_nodes_inputs_to_int32,
merge_decoders,
remove_duplicate_weights,
remove_duplicate_weights_from_tied_info,
replace_atenops_to_gather,
)
else:
121 changes: 88 additions & 33 deletions optimum/onnx/graph_transformations.py
@@ -14,25 +14,32 @@
import copy
import os
from pathlib import Path
from typing import Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union

import onnx
from onnx import ModelProto

from ..utils import logging
from .transformations_utils import (
_create_name_sharing_dict,
_deduplicate_gather_matmul,
_deduplicated_cross_model_initializers,
_find_duplicate_initializers,
_find_matching_initializers,
_get_all_inputs,
_get_onnx_opset,
_get_weights_to_tie,
_remove_redundant_initializers,
_replace_input_names,
_unify_onnx_outputs,
cast_int64_tensorproto_to_int32,
)


if TYPE_CHECKING:
import torch.nn as nn
Comment (@xenova, Contributor, Sep 1, 2023): Should this also be under an is_torch_available(), or can we assume it's available here? I see this in other places, so it's most likely not necessary - just wanted to check.
Contributor Author (fxmarty): We can probably assume that a dev has all the dependencies installed, so probably not needed.



logger = logging.get_logger()


@@ -41,6 +48,8 @@ def remove_duplicate_weights(model: ModelProto, inplace: bool = False) -> ModelProto:
Finds and removes duplicate weights in a model by keeping only unique weights, and making the duplicate values point to them.

This function only removes duplicate weights that are exactly identical (e.g., not transposed).

Args:
model (`onnx.ModelProto`): The model to remove duplicates from.
inplace (`bool`, defaults to False): Whether to perform this transformation inplace.
@@ -59,6 +68,35 @@
return model
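To illustrate the note added to the docstring above, a small sketch (toy arrays, not from the PR) of why exact-duplicate matching cannot catch tied weights exported in transposed layouts:

# Sketch: byte-level duplicate detection misses transposed copies (toy data).
import numpy as np

embed = np.arange(32, dtype=np.float32).reshape(4, 8)  # e.g. an embedding matrix
lm_head = np.ascontiguousarray(embed.T)                # same values, transposed

print(np.array_equal(embed, lm_head.T))      # True: one tied weight underneath
print(embed.tobytes() == lm_head.tobytes())  # False: raw contents differ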


def remove_duplicate_weights_from_tied_info(
onnx_model: ModelProto, torch_model: "nn.Module", tied_params: List[List[str]], save_path: str
):
"""
Tries to remove potentially duplicate ONNX initializers using the tied parameter information in tied_params.

Args:
onnx_model (`onnx.ModelProto`):
The ONNX model for which to tie potentially duplicate initializers.
torch_model (`nn.Module`):
The PyTorch model corresponding to the ONNX one.
tied_params (`List[List[str]]`):
A list of groups of torch parameters that are tied, i.e. shared: within a group, the
torch module holds a single underlying tensor.
save_path (`str`):
The path where the deduplicated ONNX model is saved.
"""
tied_groups_to_tie, tied_groups_ignored = _get_weights_to_tie(tied_params, torch_model)

initializer_name_to_idx = {}
for idx, initializer in enumerate(onnx_model.graph.initializer):
initializer_name_to_idx[initializer.name] = idx

tied_groups_map = _find_matching_initializers(tied_params, onnx_model, initializer_name_to_idx)

onnx_model = _deduplicate_gather_matmul(onnx_model, tied_groups_to_tie, tied_groups_map, initializer_name_to_idx)
check_and_save_model(onnx_model, save_path=save_path)

return onnx_model
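A usage sketch for the new helper (the paths are hypothetical, and torch_model stands for the PyTorch model that was exported):

# Usage sketch mirroring post_process_exported_models above (paths assumed).
import onnx
import torch.nn as nn
from accelerate.utils import find_tied_parameters
from optimum.onnx import remove_duplicate_weights_from_tied_info

torch_model: nn.Module = ...  # the PyTorch model the ONNX file was exported from
onnx_model = onnx.load("export_dir/decoder_model.onnx")

tied_params = find_tied_parameters(torch_model)  # e.g. [['embed.weight', 'lm_head.weight']]
remove_duplicate_weights_from_tied_info(
    onnx_model, torch_model, tied_params, save_path="export_dir/decoder_model.onnx"
)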


def replace_atenops_to_gather(model: ModelProto) -> ModelProto:
"""
Replaces broken ATenOp nodes back to Gather nodes.
@@ -89,6 +127,54 @@ def replace_atenops_to_gather(model: ModelProto) -> ModelProto:
return model


def check_and_save_model(model: onnx.ModelProto, save_path: Optional[Union[str, Path]]):
# for large models, a path must be provided instead of a ModelProto:
# https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md#checking-a-large-onnx-model-2gb
if model.ByteSize() < onnx.checker.MAXIMUM_PROTOBUF:
# For the try catch, refer to https://github.com/microsoft/onnxruntime/issues/14768
try:
onnx.checker.check_model(model)
except Exception as e:
if "No Op registered for" in str(e):
pass
else:
raise e
if save_path:
# Overwrite.
save_path = str(save_path)
if save_path.endswith(".onnx") and os.path.isfile(save_path):
os.remove(save_path)
save_path = Path(save_path).as_posix()
onnx.save(model, save_path)
Comment on lines +142 to +148
Member: So no saving is done if the model is small enough and we don't provide a save_path, only the check?
Contributor Author (fxmarty): Yes.

elif save_path is not None:
# path/to/model.onnx
save_path = Path(save_path).as_posix()

external_file_name = os.path.basename(save_path) + "_data"
# path/to/model.onnx_data
external_path = os.path.join(os.path.dirname(save_path), external_file_name)
if save_path.endswith(".onnx") and os.path.isfile(save_path):
os.remove(save_path)
if os.path.isfile(external_path):
os.remove(external_path)
onnx.save(
model,
save_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=external_file_name,
)
try:
onnx.checker.check_model(save_path)
except Exception as e:
if "No Op registered for" in str(e):
pass
else:
raise e
Comment on lines +149 to +173
Member: In this case we save the model first, meaning that it is possible that an ONNX file is saved with the check failing afterwards? It's the opposite of the first case, so just checking.
Contributor Author (fxmarty): Yes, but that is intended. The onnx checker can take either a ModelProto or a path, and for a model larger than 2 GB, a path needs to be used.

else:
logger.info("Merged ONNX model exceeds 2GB, the model will not be checked without `save_path` given.")


def merge_decoders(
decoder: Union[ModelProto, Path, str],
decoder_with_past: Union[ModelProto, Path, str],
@@ -209,38 +295,7 @@ def merge_decoders(

merged_model = onnx.helper.make_model(merged_graph, producer_name=producer_name, opset_imports=opset_imports)

# for large models, a path must be provided instead of a ModelProto:
# https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md#checking-a-large-onnx-model-2gb
if merged_model.ByteSize() < onnx.checker.MAXIMUM_PROTOBUF:
# For the try catch, refer to https://github.com/microsoft/onnxruntime/issues/14768
try:
onnx.checker.check_model(merged_model)
except Exception as e:
if "No Op registered for" in str(e):
pass
else:
raise e
if save_path:
save_path = Path(save_path).as_posix()
onnx.save(merged_model, save_path)
elif save_path is not None:
save_path = Path(save_path).as_posix()
onnx.save(
merged_model,
save_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=os.path.basename(save_path) + "_data",
)
try:
onnx.checker.check_model(save_path)
except Exception as e:
if "No Op registered for" in str(e):
pass
else:
raise e
else:
logger.info("Merged ONNX model exceeds 2GB, the model will not be checked without `save_path` given.")
check_and_save_model(merged_model, save_path=save_path)

return merged_model
