diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index add69bbf3993b3..199562f722de39 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -48,7 +48,7 @@
 import numpy as np
 import torch
 import torch.distributed as dist
-from huggingface_hub import create_repo, upload_folder
+from huggingface_hub import ModelCard, create_repo, upload_folder
 from packaging import version
 from torch import nn
 from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
@@ -3489,6 +3489,12 @@ def create_model_card(
         if not self.is_world_process_zero():
             return

+        model_card_filepath = os.path.join(self.args.output_dir, "README.md")
+        is_peft_library = False
+        if os.path.exists(model_card_filepath):
+            library_name = ModelCard.load(model_card_filepath).data.get("library_name")
+            is_peft_library = library_name == "peft"
+
         training_summary = TrainingSummary.from_trainer(
             self,
             language=language,
@@ -3502,9 +3508,12 @@
             dataset_args=dataset_args,
         )
         model_card = training_summary.to_model_card()
-        with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
+        with open(model_card_filepath, "w") as f:
             f.write(model_card)

+        if is_peft_library:
+            unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
+
     def _push_from_checkpoint(self, checkpoint_folder):
         # Only push from one node.
         if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
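
For reference, a minimal standalone sketch of the detection logic this patch adds to Trainer.create_model_card(), assuming huggingface_hub is installed. The helper name detect_peft_library is hypothetical and only mirrors the inserted check; it is not part of the patch itself.

import os

from huggingface_hub import ModelCard


def detect_peft_library(output_dir):
    # Mirrors the patched create_model_card(): if output_dir already
    # contains a README.md (typically written by an earlier PEFT save),
    # read its YAML front matter and check whether library_name is "peft".
    model_card_filepath = os.path.join(output_dir, "README.md")
    if not os.path.exists(model_card_filepath):
        return False
    library_name = ModelCard.load(model_card_filepath).data.get("library_name")
    return library_name == "peft"

When this check is true, the patch calls create_or_update_model_card on the unwrapped model after the Trainer has written its own README.md, so the PEFT-specific model card entries are re-applied rather than lost.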