
Fix PyTorch tied weights being duplicated in the exported ONNX models #1326

Merged (5 commits) on Sep 1, 2023

Conversation

fxmarty (Contributor) commented on Aug 31, 2023

This happens, for example, when the embedding and the language modeling head share the same weight.

Although PyTorch has a pass to deduplicate identical initializers (https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/passes/onnx/deduplicate_initializers.cpp), an issue arises when torch.onnx.export does constant folding, which may e.g. transpose a weight (in the nn.Linear case) and later break the initializer deduplication. See details: pytorch/pytorch#108342

For small models, the duplication can make the ONNX weights more than 30% larger than the PyTorch weights. The issue is less severe for large models. Given that constant folding is generally desirable, simply passing do_constant_folding=False is not a satisfying solution, hence this PR.

For example, bloom-560m goes from 3.0 GiB to 2.1 GiB.

Fixes pytorch/pytorch#108342 as a post-processing step for nn.Embedding & nn.Linear.
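
As a sanity check for the symptom described above, one can look for initializers with identical tensor content in an exported model. This is a rough sketch, not code from this PR; "model.onnx" is a placeholder path.

```python
from collections import defaultdict

import onnx
from onnx import numpy_helper

# Load an exported model ("model.onnx" is a placeholder path).
model = onnx.load("model.onnx")

# Group initializer names by their tensor content.
duplicates = defaultdict(list)
for initializer in model.graph.initializer:
    array = numpy_helper.to_array(initializer)
    duplicates[(array.shape, str(array.dtype), array.tobytes())].append(initializer.name)

for names in duplicates.values():
    if len(names) > 1:
        print("Initializers sharing the same content:", names)
```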

xenova (Contributor) commented on Aug 31, 2023

Awesome! 🚀
I've tested it, and bloom-560m does go down to 2.1 GiB! However, after quantizing (with this script), the quantized model is around 1.5 GiB 😅

xenova (Contributor) commented on Aug 31, 2023

I also tried using optimum-cli directly:

```bash
optimum-cli onnxruntime quantize --avx512 --onnx_model bloom_onnx --output ./quantized_bloom_onnx
```

but got the same result (1.48 GiB).

fxmarty (Contributor, Author) commented on Sep 1, 2023

@xenova This should be fixed. Could you give it a second try? I'm now getting a quantized bloom-560m that is 536 MiB.

michaelbenayoun (Member) left a comment

LGTM, thanks a lot @fxmarty!

optimum/exporters/onnx/base.py (outdated, resolved)

Comment on lines +523 to +524:

```python
if is_torch_available() and isinstance(models_and_onnx_configs[first_key][0], nn.Module):
    if is_accelerate_available():
```
michaelbenayoun (Member):

nit: import torch.nn here?

fxmarty (Author):

I put it rather at the top of the file, under if is_torch_available().
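
For illustration, a minimal sketch of the module-level pattern being referred to (using the availability helper from transformers.utils; the actual import location in optimum may differ):

```python
from transformers.utils import is_torch_available

if is_torch_available():
    # torch.nn is then usable throughout the module whenever torch is installed.
    import torch.nn as nn
```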

optimum/onnx/graph_transformations.py (resolved)

Comment on lines +140 to +146:

```python
if save_path:
    # Overwrite.
    save_path = str(save_path)
    if save_path.endswith(".onnx") and os.path.isfile(save_path):
        os.remove(save_path)
    save_path = Path(save_path).as_posix()
    onnx.save(model, save_path)
```
michaelbenayoun (Member):

So no saving is done if the model is small enough and we don't provide a save_path, only the check?

fxmarty (Author):

Yes.

Comment on lines +147 to +171:

```python
elif save_path is not None:
    # path/to/model.onnx
    save_path = Path(save_path).as_posix()

    external_file_name = os.path.basename(save_path) + "_data"
    # path/to/model.onnx_data
    external_path = os.path.join(os.path.dirname(save_path), external_file_name)
    if save_path.endswith(".onnx") and os.path.isfile(save_path):
        os.remove(save_path)
    if os.path.isfile(external_path):
        os.remove(external_path)
    onnx.save(
        model,
        save_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=external_file_name,
    )
    try:
        onnx.checker.check_model(save_path)
    except Exception as e:
        if "No Op registered for" in str(e):
            pass
        else:
            raise e
```
michaelbenayoun (Member):

In this case we save the model first, meaning that it is possible that an ONNX file is saved even though the check fails afterwards? It's the opposite of the first case, so just checking.

fxmarty (Author):

Yes, but that is intended. The ONNX checker can take either a ModelProto or a path, and for a model larger than 2 GB a path needs to be used.
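
For reference, a small sketch of the two ways of calling the checker mentioned above (file names are placeholders): onnx.checker.check_model accepts either a loaded ModelProto or a filesystem path, and the path form is required once the serialized model exceeds the 2 GB protobuf limit.

```python
import onnx

# Small model: the in-memory ModelProto can be checked directly.
model = onnx.load("small_model.onnx")
onnx.checker.check_model(model)

# Large model saved with save_as_external_data=True: pass the path instead,
# so the checker loads the graph and its external tensors from disk itself.
onnx.checker.check_model("large_model.onnx")
```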

```python
def _get_weights_to_tie(tied_params, torch_model: "nn.Module") -> Tuple[List[List[str]]]:
    """
    Find tied weights in the PyTorch model `model`, and separate them in tied weights
    for which an untying strategy exists and do not exist.
```
michaelbenayoun (Member):

nit:

Suggested change:

```diff
-    for which an untying strategy exists and do not exist.
+    for which an untying strategy exists and does not exist.
```

michaelbenayoun (Member):

Also, I did not really understand the docstring.

fxmarty (Author) commented on Sep 1, 2023:

The implementation can only re-tie Gather and MatMul weights on the ONNX side, while there may be other tied weights on the torch side. We use the PyTorch model to split the tied weight groups into one group for which we have an implementation available to tie the weights on the ONNX side, and another for which we don't (for the latter we will detect whether weights are duplicated, but won't fix them). I updated the docstring.
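
A hedged sketch of that splitting idea, using hypothetical names (split_tied_groups is not the actual helper in this PR), assuming tied weights can only be re-tied on the ONNX side when they belong to nn.Embedding / nn.Linear modules (Gather / MatMul in the graph):

```python
from typing import Dict, List, Tuple

import torch.nn as nn

SUPPORTED_MODULES = (nn.Embedding, nn.Linear)

def split_tied_groups(
    tied_params: List[List[str]], torch_model: nn.Module
) -> Tuple[List[List[str]], List[List[str]]]:
    # Map module names to modules so each parameter can be traced back to its module type.
    modules: Dict[str, nn.Module] = dict(torch_model.named_modules())
    tieable, not_tieable = [], []
    for group in tied_params:
        module_names = [param_name.rsplit(".", maxsplit=1)[0] for param_name in group]
        if all(isinstance(modules.get(name), SUPPORTED_MODULES) for name in module_names):
            tieable.append(group)
        else:
            not_tieable.append(group)
    return tieable, not_tieable
```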

optimum/onnx/transformations_utils.py (outdated, resolved)

```python
for params in tied_params:
    skip_group = False
    for param_name in params:
        module_name = ".".join(param_name.split(".")[:-1])
```
michaelbenayoun (Member):

Suggested change:

```diff
-        module_name = ".".join(param_name.split(".")[:-1])
+        module_name = param_name.rsplit(".", maxsplit=1)[0]
```

Not sure it is clearer, though.
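
(As an aside, the two expressions agree on any dotted parameter name and only differ when there is no dot at all:)

```python
param_name = "transformer.wte.weight"
assert ".".join(param_name.split(".")[:-1]) == "transformer.wte"
assert param_name.rsplit(".", maxsplit=1)[0] == "transformer.wte"

# Edge case without a dot: the join form yields "", the rsplit form returns the name itself.
assert ".".join("weight".split(".")[:-1]) == ""
assert "weight".rsplit(".", maxsplit=1)[0] == "weight"
```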

xenova (Contributor) left a comment

Can confirm that 1a5dbfc fixed it and the quantized model is now of the correct size! 🥳

Just some minor suggestions in the review:

optimum/exporters/onnx/base.py (outdated, resolved)

```python
    _remove_redundant_initializers,
    _replace_input_names,
    _unify_onnx_outputs,
    cast_int64_tensorproto_to_int32,
)


if TYPE_CHECKING:
    import torch.nn as nn
```
xenova (Contributor) commented on Sep 1, 2023:

Should this also be under is_torch_available(), or can we assume torch is available here? I see this in other places, so it's most likely not necessary; I just wanted to check.

fxmarty (Author):

We can probably assume that a dev has all the dependencies installed, so it is probably not needed.
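
A short sketch of why the TYPE_CHECKING guard is safe even without torch at runtime (count_parameters is a made-up example, not code from this PR): the guarded import is only evaluated by static type checkers, and the string annotation is never resolved when the code runs.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only; this import never runs at runtime.
    import torch.nn as nn

def count_parameters(model: "nn.Module") -> int:
    # The quoted annotation is not resolved when the function is defined or called.
    return sum(p.numel() for p in model.parameters())
```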

```python
# If not found (e.g. "lm_head.weight"), we greedily search for all initializers from
# potentially matching node names (e.g. "lm_head"). This greedy approach may find more
# initializers than wanted.
if not identical_initializer:
    module_name = ".".join(param_name.split(".")[:-1])
```
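
A rough illustration (hypothetical helper, not the PR's code) of the greedy fallback described in the comment above: when the exact initializer name is missing, collect every initializer consumed by nodes whose names contain the module name.

```python
from typing import Dict, List, Set

def find_candidate_initializers(
    param_name: str,
    node_inputs: Dict[str, List[str]],  # ONNX node name -> its input names
    initializer_names: Set[str],
) -> Set[str]:
    module_name = param_name.rsplit(".", maxsplit=1)[0]  # "lm_head.weight" -> "lm_head"
    candidates = set()
    for node_name, inputs in node_inputs.items():
        if module_name in node_name:
            candidates.update(name for name in inputs if name in initializer_names)
    return candidates
```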
fxmarty merged commit b14379e into huggingface:main on Sep 1, 2023
64 of 68 checks passed
Successfully merging this pull request may close these issues.

ONNX export constant folding messes up with shared weight deduplication