Use huggingface_hub helper function to split state dict #31091

Merged · 7 commits · Jun 12, 2024
Changes from 3 commits
setup.py (2 changes: 1 addition & 1 deletion)

@@ -117,7 +117,7 @@
     "fugashi>=1.0",
     "GitPython<3.1.19",
     "hf-doc-builder>=0.3.0",
-    "huggingface-hub>=0.23.0,<1.0",
+    "huggingface-hub>=0.23.2,<1.0",
     "importlib_metadata",
     "ipadic>=1.0.0,<2.0",
     "isort>=5.5.4",
src/transformers/dependency_versions_table.py (2 changes: 1 addition & 1 deletion)

@@ -24,7 +24,7 @@
     "fugashi": "fugashi>=1.0",
     "GitPython": "GitPython<3.1.19",
     "hf-doc-builder": "hf-doc-builder>=0.3.0",
-    "huggingface-hub": "huggingface-hub>=0.23.0,<1.0",
+    "huggingface-hub": "huggingface-hub>=0.23.2,<1.0",
     "importlib_metadata": "importlib_metadata",
     "ipadic": "ipadic>=1.0.0,<2.0",
     "isort": "isort>=5.5.4",
src/transformers/modeling_utils.py (24 changes: 20 additions & 4 deletions)

@@ -34,6 +34,7 @@
 from zipfile import is_zipfile

 import torch
+from huggingface_hub import split_torch_state_dict_into_shards
 from packaging import version
 from torch import Tensor, nn
 from torch.nn import CrossEntropyLoss, Identity
@@ -358,6 +359,10 @@ def shard_checkpoint(
         weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
             The name of the model save file.
     """
+    logger.warning(
+        "Note that `shard_checkpoint` is deprecated and will be removed in v4.44. We recommend using "
+        "`split_torch_state_dict_into_shards` from the huggingface_hub library."
+    )
     max_shard_size = convert_file_size_to_int(max_shard_size)

     sharded_state_dicts = [{}]
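
For readers migrating their own code, a minimal sketch of the replacement the warning recommends (the toy state dict is assumed for illustration; the helper ships with the huggingface_hub version pinned above):

import torch
from huggingface_hub import split_torch_state_dict_into_shards

state_dict = {"linear.weight": torch.zeros(4, 4), "linear.bias": torch.zeros(4)}

# Deprecated: shards, index = shard_checkpoint(state_dict, max_shard_size="5GB")
# Replacement: returns a StateDictSplit object instead of a (shards, index) tuple.
state_dict_split = split_torch_state_dict_into_shards(state_dict, max_shard_size="5GB")
print(state_dict_split.is_sharded)  # False: this tiny state dict fits in one shard
print(list(state_dict_split.filename_to_tensors))  # e.g. ['model.safetensors'] with the default pattern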
@@ -2585,7 +2590,17 @@ def save_pretrained(
         else:
             weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME

-        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
+        filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
+        state_dict_split = split_torch_state_dict_into_shards(
+            state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
+        )
+        # Save index if sharded
+        index = None
+        if state_dict_split.is_sharded:
+            index = {
+                "metadata": state_dict_split.metadata,
+                "weight_map": state_dict_split.tensor_to_filename,
+            }

         # Clean the folder from a previous save
         for filename in os.listdir(save_directory):
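
As an aside, the filename_pattern line above adapts transformers' weights_name convention to the template the hub helper expects; a quick sketch of what it produces (the example name is an assumption):

weights_name = "model.safetensors"
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
print(filename_pattern)  # model{suffix}.safetensors
# huggingface_hub fills {suffix} with a shard counter such as "-00001-of-00002"
# when the checkpoint is split, and with an empty string for a single-file save.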
@@ -2601,14 +2616,15 @@
             if (
                 filename.startswith(weights_no_suffix)
                 and os.path.isfile(full_filename)
-                and filename not in shards.keys()
+                and filename not in state_dict_split.filename_to_tensors.keys()
                 and is_main_process
                 and reg.fullmatch(filename_no_suffix) is not None
             ):
                 os.remove(full_filename)

         # Save the model
-        for shard_file, shard in shards.items():
+        for shard_file, tensors in state_dict_split.filename_to_tensors.items():
+            shard = {tensor: state_dict[tensor] for tensor in tensors}
             if safe_serialization:
                 # At some point we will need to deal better with save_function (used for TPU and other distributed
                 # joyfulness), but for now this is enough.
@@ -2628,7 +2644,7 @@
                 f.write(content)
             logger.info(
                 f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
-                f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
+                f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameter has been saved in the "
                 f"index located at {save_index_file}."
             )
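
Taken together, the new save path boils down to roughly the following standalone sketch (directory name, toy tensors, and the 1MB shard limit are assumptions; safetensors-based safe serialization is assumed):

import json
import os

import torch
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file

save_directory = "checkpoint"
os.makedirs(save_directory, exist_ok=True)

# Eight 256x256 fp32 tensors (~2MB total) so a 1MB limit actually forces sharding.
state_dict = {f"block.{i}.weight": torch.zeros(256, 256) for i in range(8)}
state_dict_split = split_torch_state_dict_into_shards(
    state_dict, filename_pattern="model{suffix}.safetensors", max_shard_size="1MB"
)

# Write each shard, looking the tensors up by name exactly as the PR's loop does.
for shard_file, tensors in state_dict_split.filename_to_tensors.items():
    shard = {name: state_dict[name] for name in tensors}
    save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})

# The index file is only needed when the state dict was actually split.
if state_dict_split.is_sharded:
    index = {
        "metadata": state_dict_split.metadata,
        "weight_map": state_dict_split.tensor_to_filename,
    }
    with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f:
        f.write(json.dumps(index, indent=2, sort_keys=True))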
