Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add callback for saving trainable parameters and model config #178

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions fine-tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from gptneox_attn_replace import replace_gpt_neox_attn
from peft import LoraConfig, get_peft_model
from torch.distributed import barrier
from save_callback import SavePeftModelCallback


from datasets import load_dataset
Expand Down Expand Up @@ -202,6 +203,7 @@ def train():
train_dataset=dataset["train"],
eval_dataset=None,
data_collator=data_collator)
trainer.add_callback(SavePeftModelCallback)
trainer.train()
trainer.save_state()
trainer.save_model(output_dir=training_args.output_dir)
Expand Down
9 changes: 9 additions & 0 deletions merge_lora_weights_and_save_hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,18 @@ def main(args):
print("base model", args.base_model)
print("peft model", args.peft_model)

# Load config from peft model dir if exists
# In order to reuse the rope scaling configurations
config_path = os.path.join(args.peft_model, "config.json")
if os.path.isfile(config_path):
config = transformers.AutoConfig.from_pretrained(config_path)
else:
config = transformers.AutoConfig.from_pretrained(args.base_model)

# Load model and tokenizer
model = transformers.AutoModelForCausalLM.from_pretrained(
args.base_model,
config=config,
cache_dir=args.cache_dir,
torch_dtype=torch.float16,
device_map="auto",
Expand Down
48 changes: 48 additions & 0 deletions save_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os
import logging
import torch

from transformers import (
TrainerCallback,
TrainingArguments,
TrainerState,
TrainerControl,
)

PREFIX_CHECKPOINT_DIR = "step"


class SavePeftModelCallback(TrainerCallback):
    """Trainer callback that checkpoints PEFT/LoRA training artifacts.

    On every save event it writes, under ``<output_dir>/step-<global_step>``:

    * ``trainable_params.bin`` — the extra trainable (non-LoRA) parameters
      whose names contain any of the comma-separated substrings in
      ``args.trainable_params`` (a custom field on the project's
      TrainingArguments), with the ``base_model.model.`` prefix stripped;
    * ``config.json`` — the model config, so settings such as rope scaling
      can be reused when the checkpoint is reloaded;
    * the PEFT adapter weights, via the model's ``save_pretrained``.
    """

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Save trainable params, model config, and LoRA adapter weights.

        Returns the (unmodified) ``control`` object, per the
        ``TrainerCallback`` contract.
        """
        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        os.makedirs(checkpoint_folder, exist_ok=True)

        model = kwargs["model"]

        # `trainable_params` is project-specific; fall back to "" so the
        # callback also works with a vanilla TrainingArguments instance.
        raw_params = getattr(args, "trainable_params", "") or ""
        # Strip each name before keeping it — the original appended the
        # un-stripped name, so "foo, bar" yielded " bar", which never
        # matched any parameter key below.
        modules_to_save = [
            name.strip() for name in raw_params.split(",") if name.strip()
        ]

        # Save extra trainable parameters (e.g. embeddings/norms) if any were named.
        if modules_to_save:
            state_dict = model.state_dict()
            to_save = {
                key.replace("base_model.model.", ""): value
                for key, value in state_dict.items()
                if any(module_name in key for module_name in modules_to_save)
            }
            torch.save(
                to_save, os.path.join(checkpoint_folder, "trainable_params.bin")
            )
            logging.info(f"Trainable parameters saved at: {checkpoint_folder}")

        # Save the model config (config.json) so rope scaling settings can be
        # reused when merging/reloading.  (The original comment and log message
        # here were swapped with the adapter-save step below.)
        model.config.save_pretrained(checkpoint_folder)
        logging.info(f"Model config saved at: {checkpoint_folder}")

        # Save the LoRA adapter weights — for a PEFT-wrapped model,
        # save_pretrained writes the adapter, not the full base model.
        model.save_pretrained(checkpoint_folder)
        logging.info(f"LoRA adapter weights saved at: {checkpoint_folder}")

        return control
2 changes: 2 additions & 0 deletions supervised-fine-tune-qlora.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from gptneox_attn_replace import replace_gpt_neox_attn
from peft import LoraConfig, get_peft_model
from torch.distributed import barrier
from save_callback import SavePeftModelCallback

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
Expand Down Expand Up @@ -350,6 +351,7 @@ def forward(self, x):
model.gradient_checkpointing_enable() # enable gradient checkpointing

trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
trainer.add_callback(SavePeftModelCallback)
trainer.train()
trainer.save_state()
trainer.save_model(output_dir=training_args.output_dir)
Expand Down
2 changes: 2 additions & 0 deletions supervised-fine-tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from gptneox_attn_replace import replace_gpt_neox_attn
from peft import LoraConfig, get_peft_model
from torch.distributed import barrier
from save_callback import SavePeftModelCallback

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
Expand Down Expand Up @@ -316,6 +317,7 @@ def train():
model.gradient_checkpointing_enable() # enable gradient checkpointing

trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
trainer.add_callback(SavePeftModelCallback)
trainer.train()
trainer.save_state()
trainer.save_model(output_dir=training_args.output_dir)
Expand Down