Push sharded checkpoint to hub when push_to_hub=True in `TrainingArguments` (huggingface#31808)

Save sharded checkpoint in Trainer
SunMarc authored and amyeroberts committed Jul 19, 2024
1 parent 1d34ea8 commit 031d12c
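
For context, the code path touched by this commit runs when checkpoints are pushed during training. Below is a minimal sketch of a triggering configuration, assuming a model and dataset are already prepared; the output directory, save cadence, and the `model` / `train_dataset` names are illustrative placeholders, not part of this commit.

    from transformers import Trainer, TrainingArguments

    # `model` and `train_dataset` are placeholders assumed to be defined elsewhere.
    args = TrainingArguments(
        output_dir="my-finetuned-model",  # also used as the default Hub repo name
        push_to_hub=True,                 # push checkpoint files to the Hub on each save
        save_strategy="steps",
        save_steps=500,
    )
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
    trainer.train()  # each checkpoint save goes through Trainer._push_from_checkpoint
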
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/transformers/trainer.py
@@ -22,6 +22,7 @@
 import glob
 import importlib.metadata
 import inspect
+import json
 import math
 import os
 import random
@@ -4215,6 +4216,15 @@ def _push_from_checkpoint(self, checkpoint_folder):
         output_dir = self.args.output_dir
         # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
         modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
+        # Add sharded checkpoints if we have an index
+        for index_file in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
+            index_path = os.path.join(checkpoint_folder, index_file)
+            if os.path.isfile(index_path):
+                modeling_files.append(index_file)
+                with open(index_path) as f:
+                    index = json.loads(f.read())
+                shard_files = list(set(index["weight_map"].values()))
+                modeling_files.extend(shard_files)
         if is_peft_available():
             modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
         for modeling_file in modeling_files:
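
To illustrate what the new loop reads, here is a self-contained sketch of a sharded-checkpoint index file and how deduplicating the `weight_map` values yields the shard file names to upload. The tensor keys and shard file names below are made-up examples, not taken from this commit.

    import json

    # Hypothetical contents of model.safetensors.index.json for a 2-shard checkpoint.
    index_json = """
    {
      "metadata": {"total_size": 4000000000},
      "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
        "model.layers.31.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors"
      }
    }
    """

    index = json.loads(index_json)
    # Same expression as in the diff: every tensor maps to the shard that stores it,
    # so deduplicating the values gives the set of shard files to copy and push.
    shard_files = list(set(index["weight_map"].values()))
    print(sorted(shard_files))
    # ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors']
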
