From 5324bf9c07c318015eccc5fba370a81368a8df28 Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Thu, 7 Dec 2023 17:36:02 +0530
Subject: [PATCH] update `create_model_card` to properly save peft details when using Trainer with PEFT (#27754)

* update `create_model_card` to properly save peft details when using Trainer with PEFT

* nit

* Apply suggestions from code review

Co-authored-by: Benjamin Bossan

---------

Co-authored-by: Benjamin Bossan
---
 src/transformers/trainer.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 3c9e4420124012..422be2247bd638 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -48,7 +48,7 @@
 import numpy as np
 import torch
 import torch.distributed as dist
-from huggingface_hub import create_repo, upload_folder
+from huggingface_hub import ModelCard, create_repo, upload_folder
 from packaging import version
 from torch import nn
 from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
@@ -3494,6 +3494,12 @@ def create_model_card(
         if not self.is_world_process_zero():
             return
 
+        model_card_filepath = os.path.join(self.args.output_dir, "README.md")
+        is_peft_library = False
+        if os.path.exists(model_card_filepath):
+            library_name = ModelCard.load(model_card_filepath).data.get("library_name")
+            is_peft_library = library_name == "peft"
+
         training_summary = TrainingSummary.from_trainer(
             self,
             language=language,
@@ -3507,9 +3513,12 @@ def create_model_card(
             dataset_args=dataset_args,
         )
         model_card = training_summary.to_model_card()
-        with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
+        with open(model_card_filepath, "w") as f:
             f.write(model_card)
 
+        if is_peft_library:
+            unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
+
     def _push_from_checkpoint(self, checkpoint_folder):
         # Only push from one node.
         if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END: