From 87703089fa0ad60f008b7a7990f5cf3e77ccd26e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 19 Jun 2024 04:53:26 +1000 Subject: [PATCH] Ollama (#665) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update llama.py * offload * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * continued pretraining trainer * Update trainer.py * Update trainer.py * Update trainer.py * Update trainer.py * is_bfloat16_supported * Update __init__.py * Update README.md * Update llama.py * is_bfloat16_supported * Update __init__.py * Mistral v3 * Phi 3 medium * Update chat_templates.py * Update chat_templates.py * Phi-3 * Update save.py * Update README.md Mistral v3 to Mistral v0.3 * Untrained tokens * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update llama.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update save.py * Update save.py * Update save.py * checkpoint * Update _utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update llama.py * accelerate * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update tokenizer_utils.py * train_dataloader * Update llama.py * Update llama.py * Update llama.py * use_fast_convert * Update save.py * Update save.py * Update save.py * Update save.py * remove_special_tokens * Ollama * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update llama.py * Update chat_templates.py * Support bfloat16 GGUF * Update save.py * Update llama.py * fast_forward_inference * Update mapper.py * Update loader.py * Update llama.py * Update tokenizer_utils.py * info * edits * Create chat template * Fix tokenizer * Update tokenizer_utils.py * fix case where gguf saving fails due to first_conversion dtype (#630) * Support revision parameter in FastLanguageModel.from_pretrained (#629) * support `revision` parameter * match unsloth formatting of named parameters * clears any selected_adapters before calling internal_model.save_pretrained (#609) * Update __init__.py (#602) Check for incompatible modules before importing unsloth * Fixed unsloth/tokenizer_utils.py for chat training (#604) * Add GGML saving option to Unsloth for easier Ollama model creation and testing. (#345) * Add save to llama.cpp GGML to save.py. * Fix conversion command and path of convert to GGML function. * Add autosaving lora to the GGML function * Create lora save function for conversion to GGML * Test fix #2 for saving lora * Test fix #3 to save the lora adapters to convert to GGML * Remove unwated tokenizer saving for conversion to ggml and added a few print statements. * Needed tokenizer for saving, added it back, also made it more unslothy style by having positional arguments, and added a few messages. 
* Positional arguments didn't work out, so reverted to older version of the code, and added a few comments. * Test fix 1 for arch * Test fix 2 new Mistral error. * Test fix 3 * Revert to old version for testing. * Upload issue test fix 1 * Fix 2 uploading ggml * Positional ags added. * Temporray remove positional args * Fix upload again!!! * Add print statements and fix link * Make the calling name better * Create local saving for GGML * Add choosing directory to save local GGML. * Fix lil variable error in the save_to_custom_dir func * docs: Add LoraConfig parameters documentation (#619) * llama.cpp failing (#371) llama.cpp is failing to generate quantize versions for the trained models. Error: ```bash You might have to compile llama.cpp yourself, then run this again. You do not need to close this Python program. Run the following commands in a new terminal: You must run this in the same folder as you're saving your model. git clone https://github.com/ggerganov/llama.cpp cd llama.cpp && make clean && LLAMA_CUDA=1 make all -j Once that's done, redo the quantization. ``` But when i do clone this with recursive it works. Co-authored-by: Daniel Han * fix libcuda_dirs import for triton 3.0 (#227) * fix libcuda_dirs import for triton 3.0 * Update __init__.py * Update __init__.py --------- Co-authored-by: Daniel Han * Update save.py * Update __init__.py * Update fast_lora.py * Update save.py * Update save.py * Update save.py * Update loader.py * Update save.py * Update save.py * quantize now llama-quantize * Update chat_templates.py * Update loader.py * Update mapper.py * Update __init__.py * embedding size * Update qwen2.py * docs * Update README.md * Update qwen2.py * README: Fix minor typo. (#559) * README: Fix minor typo. One-character typo fix while reading. 
* Update README.md --------- Co-authored-by: Daniel Han * Update mistral.py * Update qwen2.py * Update qwen2.py * Update qwen2.py * Update llama.py * Update llama.py * Update llama.py * Update README.md * FastMistralModel * Update mistral.py * Update mistral.py * Update mistral.py * Update mistral.py * Update mistral.py * Auto check rope scaling * Update llama.py * Update llama.py * Update llama.py * GPU support * Typo * Update gemma.py * gpu * Multiple GGUF saving * Update save.py * Update save.py * check PEFT and base * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update chat_templates.py * Fix breaking bug in save.py with interpreting quantization_method as a string when saving to gguf (#651) * Nightly (#649) * Update llama.py * offload * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * continued pretraining trainer * Update trainer.py * Update trainer.py * Update trainer.py * Update trainer.py * is_bfloat16_supported * Update __init__.py * Update README.md * Update llama.py * is_bfloat16_supported * Update __init__.py * Mistral v3 * Phi 3 medium * Update chat_templates.py * Update chat_templates.py * Phi-3 * Update save.py * Update README.md Mistral v3 to Mistral v0.3 * Untrained tokens * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update llama.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update save.py * Update save.py * Update save.py * checkpoint * Update _utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update llama.py * accelerate * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update _utils.py * Update tokenizer_utils.py * train_dataloader * Update llama.py * Update llama.py * Update llama.py * use_fast_convert * Update save.py * Update save.py * Update save.py * Update save.py * remove_special_tokens * Ollama * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update llama.py * Update chat_templates.py * Support bfloat16 GGUF * Update save.py * Update llama.py * fast_forward_inference * Update mapper.py * Update loader.py * Update llama.py * Update tokenizer_utils.py * info * edits * Create chat template * Fix tokenizer * Update tokenizer_utils.py * fix case where gguf saving fails due to first_conversion dtype (#630) * Support revision parameter in FastLanguageModel.from_pretrained (#629) * support `revision` parameter * match unsloth formatting of named parameters * clears any selected_adapters before calling internal_model.save_pretrained (#609) * Update __init__.py (#602) Check for incompatible modules before importing unsloth * Fixed unsloth/tokenizer_utils.py for chat training (#604) * Add GGML saving option to Unsloth for easier Ollama model creation and testing. (#345) * Add save to llama.cpp GGML to save.py. 
* Fix conversion command and path of convert to GGML function. * Add autosaving lora to the GGML function * Create lora save function for conversion to GGML * Test fix #2 for saving lora * Test fix #3 to save the lora adapters to convert to GGML * Remove unwated tokenizer saving for conversion to ggml and added a few print statements. * Needed tokenizer for saving, added it back, also made it more unslothy style by having positional arguments, and added a few messages. * Positional arguments didn't work out, so reverted to older version of the code, and added a few comments. * Test fix 1 for arch * Test fix 2 new Mistral error. * Test fix 3 * Revert to old version for testing. * Upload issue test fix 1 * Fix 2 uploading ggml * Positional ags added. * Temporray remove positional args * Fix upload again!!! * Add print statements and fix link * Make the calling name better * Create local saving for GGML * Add choosing directory to save local GGML. * Fix lil variable error in the save_to_custom_dir func * docs: Add LoraConfig parameters documentation (#619) * llama.cpp failing (#371) llama.cpp is failing to generate quantize versions for the trained models. Error: ```bash You might have to compile llama.cpp yourself, then run this again. You do not need to close this Python program. Run the following commands in a new terminal: You must run this in the same folder as you're saving your model. git clone https://github.com/ggerganov/llama.cpp cd llama.cpp && make clean && LLAMA_CUDA=1 make all -j Once that's done, redo the quantization. ``` But when i do clone this with recursive it works. Co-authored-by: Daniel Han * fix libcuda_dirs import for triton 3.0 (#227) * fix libcuda_dirs import for triton 3.0 * Update __init__.py * Update __init__.py --------- Co-authored-by: Daniel Han * Update save.py * Update __init__.py * Update fast_lora.py * Update save.py * Update save.py * Update save.py * Update loader.py * Update save.py * Update save.py * quantize now llama-quantize * Update chat_templates.py * Update loader.py * Update mapper.py * Update __init__.py * embedding size * Update qwen2.py * docs * Update README.md * Update qwen2.py * README: Fix minor typo. (#559) * README: Fix minor typo. One-character typo fix while reading. 
* Update README.md --------- Co-authored-by: Daniel Han * Update mistral.py * Update qwen2.py * Update qwen2.py * Update qwen2.py * Update llama.py * Update llama.py * Update llama.py * Update README.md * FastMistralModel * Update mistral.py * Update mistral.py * Update mistral.py * Update mistral.py * Update mistral.py * Auto check rope scaling * Update llama.py * Update llama.py * Update llama.py * GPU support * Typo * Update gemma.py * gpu * Multiple GGUF saving * Update save.py * Update save.py * check PEFT and base * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update chat_templates.py --------- Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com> Co-authored-by: Eliot Hall <60240707+chrehall68@users.noreply.github.com> Co-authored-by: Rickard Edén Co-authored-by: XiaoYang Co-authored-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Co-authored-by: mahiatlinux <110882203+mahiatlinux@users.noreply.github.com> Co-authored-by: Sébastien De Greef Co-authored-by: Alberto Ferrer Co-authored-by: Thomas Viehmann Co-authored-by: Walter Korman * Fix bug in save.py with interpreting quantization_method as a string that prevents GGUF from saving * Implemented better list management and then forgot to actually call the new list variable, fixed * Check type of given quantization method and return type error if not list or string * Update save.py --------- Co-authored-by: Daniel Han Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com> Co-authored-by: Eliot Hall <60240707+chrehall68@users.noreply.github.com> Co-authored-by: Rickard Edén Co-authored-by: XiaoYang Co-authored-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Co-authored-by: mahiatlinux <110882203+mahiatlinux@users.noreply.github.com> Co-authored-by: Sébastien De Greef Co-authored-by: Alberto Ferrer Co-authored-by: Thomas Viehmann Co-authored-by: Walter Korman * Revert "Fix breaking bug in save.py with interpreting quantization_method as …" (#652) This reverts commit 30605dec2322435eec9753c7f566a0ff610ab52c. * Revert "Revert "Fix breaking bug in save.py with interpreting quantization_me…" (#653) This reverts commit e2b2083b621208b15923595cd7f509584ff566bc. * Update llama.py * peft * patch * Update loader.py * retrain * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * Update llama.py * offload * Update llama.py * Create a starter script for command-line training to integrate in ML ops pipelines. 
(#623) * Update chat_templates.py * Ollama * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py * Update chat_templates.py --------- Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com> Co-authored-by: Eliot Hall <60240707+chrehall68@users.noreply.github.com> Co-authored-by: Rickard Edén Co-authored-by: XiaoYang Co-authored-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Co-authored-by: mahiatlinux <110882203+mahiatlinux@users.noreply.github.com> Co-authored-by: Sébastien De Greef Co-authored-by: Alberto Ferrer Co-authored-by: Thomas Viehmann Co-authored-by: Walter Korman Co-authored-by: ArcadaLabs-Jason <52756218+ArcadaLabs-Jason@users.noreply.github.com>
---
 unsloth-cli.py             | 221 +++++++++++++++++++++++++
 unsloth/chat_templates.py  | 327 +++++++++++++++++++++++++++++++++----
 unsloth/models/llama.py    |  52 +++++++-
 unsloth/tokenizer_utils.py |   3 +-
 4 files changed, 569 insertions(+), 34 deletions(-)
 create mode 100644 unsloth-cli.py

diff --git a/unsloth-cli.py b/unsloth-cli.py
new file mode 100644
index 00000000..ddb0ac8b
--- /dev/null
+++ b/unsloth-cli.py
@@ -0,0 +1,221 @@
+#!/usr/bin/env python3
+
+"""
+🦥 Starter Script for Fine-Tuning FastLanguageModel with Unsloth
+
+This script is designed as a starting point for fine-tuning your models using Unsloth.
+It includes configurable options for model loading, PEFT parameters, training arguments,
+and model saving/pushing functionalities.
+
+You will likely want to customize this script to suit your specific use case
+and requirements.
+
+Here are a few suggestions for customization:
+  - Modify the dataset loading and preprocessing steps to match your data.
+  - Customize the model saving and pushing configurations.
+
+Usage (most options have valid default values; this extended example is for demonstration purposes):
+  python unsloth-cli.py --model_name "unsloth/llama-3-8b" --max_seq_length 8192 --dtype None --load_in_4bit \
+  --r 64 --lora_alpha 32 --lora_dropout 0.1 --bias "none" --use_gradient_checkpointing "unsloth" \
+  --random_state 3407 --use_rslora --per_device_train_batch_size 4 --gradient_accumulation_steps 8 \
+  --warmup_steps 5 --max_steps 400 --learning_rate 2e-6 --logging_steps 1 --optim "adamw_8bit" \
+  --weight_decay 0.005 --lr_scheduler_type "linear" --seed 3407 --output_dir "outputs" \
+  --report_to "tensorboard" --save_model --save_path "model" --quantization "f16" \
+  --push_model --hub_path "hf/model" --hub_token "your_hf_token"
+
+To see a full list of configurable options, use:
+  python unsloth-cli.py --help
+
+Happy fine-tuning!
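+
+Note: `--quantization` accepts one or more GGUF quantization formats (it is declared with
+nargs="+"), so with `--save_gguf` a single run can export several GGUF files, for example
+(illustrative):
+  python unsloth-cli.py --save_model --save_gguf --quantization q4_k_m q8_0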
+""" + +import argparse + +def run(args): + import torch + from unsloth import FastLanguageModel + from datasets import load_dataset + from trl import SFTTrainer + from transformers import TrainingArguments + from unsloth import is_bfloat16_supported + import logging + logging.getLogger('hf-to-gguf').setLevel(logging.WARNING) + + # Load model and tokenizer + model, tokenizer = FastLanguageModel.from_pretrained( + model_name=args.model_name, + max_seq_length=args.max_seq_length, + dtype=args.dtype, + load_in_4bit=args.load_in_4bit, + ) + + # Configure PEFT model + model = FastLanguageModel.get_peft_model( + model, + r=args.r, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj"], + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + bias=args.bias, + use_gradient_checkpointing=args.use_gradient_checkpointing, + random_state=args.random_state, + use_rslora=args.use_rslora, + loftq_config=args.loftq_config, + ) + + alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. + + ### Instruction: + {} + + ### Input: + {} + + ### Response: + {}""" + + EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN + def formatting_prompts_func(examples): + instructions = examples["instruction"] + inputs = examples["input"] + outputs = examples["output"] + texts = [] + for instruction, input, output in zip(instructions, inputs, outputs): + text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN + texts.append(text) + return {"text": texts} + + # Load and format dataset + dataset = load_dataset(args.dataset, split="train") + dataset = dataset.map(formatting_prompts_func, batched=True) + print("Data is formatted and ready!") + + # Configure training arguments + training_args = TrainingArguments( + per_device_train_batch_size=args.per_device_train_batch_size, + gradient_accumulation_steps=args.gradient_accumulation_steps, + warmup_steps=args.warmup_steps, + max_steps=args.max_steps, + learning_rate=args.learning_rate, + fp16=not is_bfloat16_supported(), + bf16=is_bfloat16_supported(), + logging_steps=args.logging_steps, + optim=args.optim, + weight_decay=args.weight_decay, + lr_scheduler_type=args.lr_scheduler_type, + seed=args.seed, + output_dir=args.output_dir, + report_to=args.report_to, + ) + + # Initialize trainer + trainer = SFTTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset, + dataset_text_field="text", + max_seq_length=args.max_seq_length, + dataset_num_proc=2, + packing=False, + args=training_args, + ) + + # Train model + trainer_stats = trainer.train() + + # Save model + if args.save_model: + # if args.quantization_method is a list, we will save the model for each quantization method + if args.save_gguf: + if isinstance(args.quantization, list): + for quantization_method in args.quantization: + print(f"Saving model with quantization method: {quantization_method}") + model.save_pretrained_gguf( + args.save_path, + tokenizer, + quantization_method=quantization_method, + ) + if args.push_model: + model.push_to_hub_gguf( + hub_path=args.hub_path, + hub_token=args.hub_token, + quantization_method=quantization_method, + ) + else: + print(f"Saving model with quantization method: {args.quantization}") + model.save_pretrained_gguf(args.save_path, tokenizer, quantization_method=args.quantization) + if args.push_model: + model.push_to_hub_gguf( + hub_path=args.hub_path, + hub_token=args.hub_token, + 
+                        quantization_method=args.quantization,
+                    )
+        else:
+            model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)
+            if args.push_model:
+                model.push_to_hub_merged(args.save_path, tokenizer, args.hub_token)
+    else:
+        print("Warning: The model is not saved!")
+
+
+if __name__ == "__main__":
+
+    # Define argument parser
+    parser = argparse.ArgumentParser(description="🦥 Fine-tune your LLM faster using Unsloth!")
+
+    model_group = parser.add_argument_group("🤖 Model Options")
+    model_group.add_argument('--model_name', type=str, default="unsloth/llama-3-8b", help="Model name to load")
+    model_group.add_argument('--max_seq_length', type=int, default=2048, help="Maximum sequence length, default is 2048. We auto support RoPE Scaling internally!")
+    model_group.add_argument('--dtype', type=str, default=None, help="Data type for model (None for auto detection)")
+    model_group.add_argument('--load_in_4bit', action='store_true', help="Use 4bit quantization to reduce memory usage")
+    model_group.add_argument('--dataset', type=str, default="yahma/alpaca-cleaned", help="Hugging Face dataset to use for training")
+
+    lora_group = parser.add_argument_group("🧠 LoRA Options", "These options are used to configure the LoRA model.")
+    lora_group.add_argument('--r', type=int, default=16, help="Rank for the LoRA model, default is 16. (common values: 8, 16, 32, 64, 128)")
+    lora_group.add_argument('--lora_alpha', type=int, default=16, help="LoRA alpha parameter, default is 16. (common values: 8, 16, 32, 64, 128)")
+    lora_group.add_argument('--lora_dropout', type=float, default=0, help="LoRA dropout rate, default is 0.0 which is optimized.")
+    lora_group.add_argument('--bias', type=str, default="none", help="Bias setting for LoRA")
+    lora_group.add_argument('--use_gradient_checkpointing', type=str, default="unsloth", help="Use gradient checkpointing")
+    lora_group.add_argument('--random_state', type=int, default=3407, help="Random state for reproducibility, default is 3407.")
+    lora_group.add_argument('--use_rslora', action='store_true', help="Use rank stabilized LoRA")
+    lora_group.add_argument('--loftq_config', type=str, default=None, help="Configuration for LoftQ")
+
+
+    training_group = parser.add_argument_group("🎓 Training Options")
+    training_group.add_argument('--per_device_train_batch_size', type=int, default=2, help="Batch size per device during training, default is 2.")
+    training_group.add_argument('--gradient_accumulation_steps', type=int, default=4, help="Number of gradient accumulation steps, default is 4.")
+    training_group.add_argument('--warmup_steps', type=int, default=5, help="Number of warmup steps, default is 5.")
+    training_group.add_argument('--max_steps', type=int, default=400, help="Maximum number of training steps, default is 400.")
+    training_group.add_argument('--learning_rate', type=float, default=2e-4, help="Learning rate, default is 2e-4.")
+    training_group.add_argument('--optim', type=str, default="adamw_8bit", help="Optimizer type.")
+    training_group.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay, default is 0.01.")
+    training_group.add_argument('--lr_scheduler_type', type=str, default="linear", help="Learning rate scheduler type, default is 'linear'.")
+    training_group.add_argument('--seed', type=int, default=3407, help="Seed for reproducibility, default is 3407.")
+
+
+    # Report/Logging arguments
+    report_group = parser.add_argument_group("📊 Report Options")
+    report_group.add_argument('--report_to', type=str, default="tensorboard",
+                              choices=["azure_ml",
"clearml", "codecarbon", "comet_ml", "dagshub", "dvclive", "flyte", "mlflow", "neptune", "tensorboard", "wandb", "all", "none"], + help="The list of integrations to report the results and logs to. Supported platforms are: \n\t\t 'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', 'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations installed, 'none' for no integrations.") + report_group.add_argument('--logging_steps', type=int, default=1, help="Logging steps, default is 1") + + # Saving and pushing arguments + save_group = parser.add_argument_group('💾 Save Model Options') + save_group.add_argument('--output_dir', type=str, default="outputs", help="Output directory") + save_group.add_argument('--save_model', action='store_true', help="Save the model after training") + save_group.add_argument('--save_method', type=str, default="merged_16bit", choices=["merged_16bit", "merged_4bit", "lora"], help="Save method for the model, default is 'merged_16bit'") + save_group.add_argument('--save_gguf', action='store_true', help="Convert the model to GGUF after training") + save_group.add_argument('--save_path', type=str, default="model", help="Path to save the model") + save_group.add_argument('--quantization', type=str, default="q8_0", nargs="+", + help="Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf ") + + push_group = parser.add_argument_group('🚀 Push Model Options') + push_group.add_argument('--push_model', action='store_true', help="Push the model to Hugging Face hub after training") + push_group.add_argument('--push_gguf', action='store_true', help="Push the model as GGUF to Hugging Face hub after training") + push_group.add_argument('--hub_path', type=str, default="hf/model", help="Path on Hugging Face hub to push the model") + push_group.add_argument('--hub_token', type=str, help="Token for pushing the model to Hugging Face hub") + + args = parser.parse_args() + run(args) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index a2a02d7e..7b6da3e4 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -17,9 +17,11 @@ "test_chat_templates", "test_hf_gguf_equivalence", "remove_special_tokens", - "standardize_dataset", - "construct_chat_template", + "to_sharegpt", + "standardize_sharegpt", + "apply_chat_template", + "test_construct_chat_template", "create_ollama_modelfile", ] @@ -32,6 +34,7 @@ import shutil from .tokenizer_utils import * from .models._utils import patch_tokenizer +import re CHAT_TEMPLATES = {} @@ -713,21 +716,209 @@ def remove_special_tokens(tokenizer, prompt): pass -def standardize_dataset( +def _parse_combined_prompt(combined_prompt, dataset): + # Find {...} + possible_columns = re.findall(r"\{(.+?)\}", combined_prompt) + dataset_columns = set(dataset.column_names) + for column in possible_columns: + if column not in dataset_columns: + raise KeyError( + f"Unsloth: Your prompt includes '{column}' but this does not exist in the dataset. 
"\ + f"Only allowed columns are {list(dataset_columns)}" + ) + pass + pass + + # Find [[...]] + optional_prompts = list(re.finditer(r"\[\[.+?\]\]", combined_prompt, flags = re.DOTALL | re.MULTILINE)) + optional_prompts = [(x.span(), x.group(0)) for x in optional_prompts] + + final_optional_prompts = [] + if len(optional_prompts) != 0: + # Add left + left = optional_prompts[0] + l = left[0][0] + if l != 0: final_optional_prompts.append(combined_prompt[:l]) + + # Add in between + for left, right in zip(optional_prompts[:-1], optional_prompts[1:]): + l, r = left[0][-1], right[0][0] + final_optional_prompts.append(left) + if l != r: final_optional_prompts.append(combined_prompt[l : r]) + pass + final_optional_prompts.append(optional_prompts[-1]) + + # Add right + right = optional_prompts[-1] + r = right[0][1] + if r != len(combined_prompt): final_optional_prompts.append(combined_prompt[r:]) + else: + # Just add in the entire string + final_optional_prompts.append(combined_prompt) + pass + + check_combined = "".join(x if type(x) is str else x[1] for x in final_optional_prompts) + assert(combined_prompt == check_combined) + + return possible_columns, final_optional_prompts +pass + + +def _create_formatter(possible_columns, final_optional_prompts, user_column_name): + # Start final prompt! + function = ["def __combined_prompt_processor__(examples):"] + columns = list(set(possible_columns)) + for column in columns: + function.append(f"{' '*4}{column}__ = examples['{column}']") + function.append(f"{' '*4}texts = []") + function.append(f"{' '*4}for ({', '.join(columns)}) in zip({', '.join(f'{x}__' for x in columns)}):") + + # Add optional tags as well! + final_prompt = "" + formatter = [] + + for j, optional_prompt in enumerate(final_optional_prompts): + if type(optional_prompt) is str: + columns = re.findall(r"\{(.+?)\}", optional_prompt) + formatter += columns + # Must escape \n \r + final_prompt += optional_prompt.encode("unicode-escape").decode("utf-8") + else: + where, prompt = optional_prompt + # Strip [[...]] + # Must escape \n \r + prompt = prompt[2:-2].encode("unicode-escape").decode("utf-8") + columns = re.findall(r"\{(.+?)\}", prompt) + x = f"__optional_{j}__" + prompt = f"{' '*8}{x} = '{prompt}'.format({', '.join(f'{x} = {x}' for x in columns)}) if input else ''" + function.append(prompt) + formatter.append(x) + final_prompt += "{" + x + "}" + pass + pass + + function.insert(1, f"{' '*4}__combined_prompt__ = '{final_prompt}'") + function.append(f"{' '*8}texts.append("\ + f"__combined_prompt__.format({', '.join(f'{x} = {x}' for x in formatter)}))") + function.append(f"{' '*4}return " + "{ " + f"'{user_column_name}' : texts" + " }") + return "\n".join(function) +pass + + +def to_sharegpt( + dataset, + merged_prompt = "", + merged_column_name = "instruction", + output_column_name = "output", + remove_unsued_columns = True, + conversation_extension = 1, + random_state = 3407, +): + """ + Converts a dataset to ShareGPT style. + ShareGPT requires only 1 input and 1 output field. + This means one has to merge multiple columns into 1 for 1 input field. + Use `conversation_extension` to increase the length of each conversation by randomnly + selecting a few and packing them into 1. 
+ + merged_prompt = "", Prompt to merge columns into 1 input + merged_column_name = "instruction", Final column name for the input field + output_column_name = "output", Final column name for the output field + remove_unsued_columns = True, + conversation_extension = 1, Automatically combines `conversation_extension` convos into 1 + random_state = 3407, + """ + if "conversations" in dataset.column_names: + convo = dataset[0]["conversations"] + if type(convo) is list: + raise TypeError("Unsloth: Your dataset is probably already in ShareGPT format!") + pass + pass + + possible_columns, final_optional_prompts = _parse_combined_prompt(merged_prompt, dataset) + function = _create_formatter(possible_columns, final_optional_prompts, merged_column_name) + exec(function, globals()) + dataset = dataset.map(__combined_prompt_processor__, batched = True, desc = "Merging columns") + + def __convert_to_sharegpt__(examples): + users = examples[merged_column_name] + assistants = examples[output_column_name] + texts = [] + for user, assistant in zip(users, assistants): + texts.append([ + {"from" : "user", "content" : user }, + {"from" : "assistant", "content" : assistant}, + ]) + pass + return { "conversations" : texts, } + pass + + dataset = dataset.map( + __convert_to_sharegpt__, + batched = True, + desc = "Converting to ShareGPT", + # Remove unsued columns! + remove_columns = dataset.column_names if remove_unsued_columns else None, + ) + + # Randomnly concat conversations to create a long stream! + from datasets import concatenate_datasets + n_extensions = max(conversation_extension-1, 0) + if n_extensions == 0: return dataset + + dataset = dataset.rename_columns({"conversations" : f"conversations0"}) + all_shuffled = [dataset] + for j in range(1, n_extensions+1): + shuffled = dataset.shuffle(seed = random_state+j).rename_columns({"conversations0" : f"conversations{j}"}) + all_shuffled.append(shuffled) + pass + dataset = concatenate_datasets(all_shuffled, axis = 1) + + # Combine them into 1 + function = "def __combine_conversations__(examples):\n" + n_extensions += 1 + for j in range(n_extensions): + function += f"{' '*4}conversations{j}__ = examples['conversations{j}']\n" + function += f"{' '*4}convos = []\n" + function += f"{' '*4}for ({', '.join(f'conversations{j}' for j in range(n_extensions))}) "\ + f"in zip({', '.join(f'conversations{j}__' for j in range(n_extensions))}):\n" + function += f"{' '*8}convos.append("\ + f"{'+'.join(f'conversations{j}' for j in range(n_extensions))})\n" + function += f"{' '*4}return " + "{ " + f"'conversations' : convos" + " }" + + # Map function + exec(function, globals()) + dataset = dataset.map( + __combine_conversations__, + batched = True, + desc = "Extending conversations", + # Remove unsued columns! + remove_columns = dataset.column_names if remove_unsued_columns else None, + ) + return dataset +pass + + +def standardize_sharegpt( dataset, - conversation_key = "conversations", - system_message = None, aliases_for_system = ["system",], aliases_for_user = ["user", "human", "input",], aliases_for_assistant = ["gpt", "assistant", "output",], ): """ - Standardizes ShareGPT and other formats to user/assistant Hugging Face format. + Standardizes ShareGPT and other formats to user/assistant Hugging Face format. + + Get aliases for the system, user and assistant roles. + These shall map to "system", "user" and "assistant" respectively. 
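+
+    For example (illustrative), a ShareGPT-style message such as
+        {"from": "human", "value": "What is 2+2?"}
+    is rewritten into the Hugging Face chat format
+        {"role": "user", "content": "What is 2+2?"}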
+ + aliases_for_system = ["system",], + aliases_for_user = ["user", "human", "input",], + aliases_for_assistant = ["gpt", "assistant", "output",], """ import collections import itertools - convos = dataset[:10][conversation_key] + convos = dataset[:10]["conversations"] uniques = collections.defaultdict(list) for convo in convos: for message in convo: @@ -768,24 +959,19 @@ def standardize_dataset( for x in aliases_for_assistant: aliases_mapping[x] = "assistant" def _standardize_dataset(examples): - convos = examples[conversation_key] + convos = examples["conversations"] all_convos = [] for convo in convos: - new_convo = [] - if len(convo) == 0: continue - has_system = aliases_mapping[convo[0][role_key]] == "system" - if not has_system and system_message is not None: - new_convo.append({ "role" : "system", "content" : system_message, }) - for message in convo: - role = aliases_mapping[message[role_key]] - new_convo.append({ "role" : role, "content" : message[content_key], }) - pass + new_convo = [ + { "role" : aliases_mapping[message[role_key]], "content" : message[content_key], } + for message in convo + ] all_convos.append(new_convo) pass - return { conversation_key : all_convos, } + return { "conversations" : all_convos, } pass - return dataset.map(_standardize_dataset, batched = True,) + return dataset.map(_standardize_dataset, batched = True, desc = "Standardizing format") pass @@ -837,7 +1023,7 @@ def construct_chat_template( \ tokenizer = None, -template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|> +chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|> {SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|> @@ -851,7 +1037,7 @@ def construct_chat_template( \ default_system_message = \ "Below are some instructions that describe some tasks. Write responses that appropriately complete each request.", - + extra_eos_tokens = None, ): @@ -865,6 +1051,7 @@ def construct_chat_template( \ assert(tokenizer is not None) if extra_eos_tokens is None: extra_eos_tokens = [] + elif type(extra_eos_tokens) is str: extra_eos_tokens = [extra_eos_tokens,] vocab = tokenizer.get_vocab() for extra_eos in extra_eos_tokens: @@ -883,11 +1070,30 @@ def construct_chat_template( \ "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"\ "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n" + # Check for EOS after {OUTPUT} + if tokenizer.eos_token is not None: + extra_eos_tokens.insert(0, tokenizer.eos_token) + if len(extra_eos_tokens) == 0: + raise RuntimeError( + "Unsloth: Your tokenizer does not have an EOS token? Please provide one via extra_eos_tokens!" 
+ ) + pass + + count_eos = 0 + for eos in extra_eos_tokens: + count_eos += len(re.findall(r"{OUTPUT}" + eos.encode("unicode-escape").decode("utf-8"), chat_template)) + pass + if count_eos == 0: + logger.warning("Unsloth: We automatically added an EOS token to stop endless generations.") + eos = extra_eos_tokens[0] + chat_template = re.sub(r"{OUTPUT}", r"{OUTPUT}" + eos.encode("unicode-escape").decode("utf-8"), chat_template) + pass + # O(N^2) search finding 2 repeatted pieces of text - j = len(template)-1 + j = len(chat_template)-1 at_least_one = False while j > 0: - found = template.rfind(template[j:], 0, j) + found = chat_template.rfind(chat_template[j:], 0, j) if found == -1: break j -= 1 at_least_one = True @@ -895,19 +1101,18 @@ def construct_chat_template( \ if j > 0: j += 1 else: raise RuntimeError(error_msg) - if not at_least_one: raise RuntimeError(error_msg) # Repeatted text - instruction_response = template[j:] + instruction_response = chat_template[j:] if instruction_response.count("{INPUT}") != 1 or instruction_response.count("{OUTPUT}") != 1: raise RuntimeError(error_msg) pass # 1st System, Instruction, Output pair - left = template[:j] + left = chat_template[:j] # 2nd Instruction, Output pair - right = template[j:] + right = chat_template[j:] # Isolate input extra_eos_tokens_regex = "|".join(f"(?:{re.escape(x)})" for x in extra_eos_tokens) @@ -952,7 +1157,12 @@ def construct_chat_template( \ ollama_system = ollama_system[len(tokenizer.bos_token):] pass pass - system_modelfile = "{{ if .System }}" + ollama_system.replace("{SYSTEM}", "{{ .System }}") + "{{ end }}" + # Check system + if "{SYSTEM}" in ollama_system: + system_modelfile = "{{ if .System }}" + ollama_system.replace("{SYSTEM}", "{{ .System }}") + "{{ end }}" + else: + system_modelfile = ollama_system + pass input_modelfile = "{{ if .Prompt }}" + input_part .replace("{INPUT}", "{{ .Prompt }}") + "{{ end }}" output_modelfile = output_part.replace("{OUTPUT}", "{{ .Response }}") @@ -1005,6 +1215,14 @@ def process(part, which, content = "message['content']"): partial_system = process(system_part, "{SYSTEM}", "messages[0]['content']") partial_system = partial_system.replace("{SYSTEM}", "") + # If {SYSTEM} is non existent, simply just use the content + if "{SYSTEM}" not in partial_system: + partial_system = "messages[0]['content']" + else: + if default_system_message is None: + raise RuntimeError("Unsloth: Please specify a default system message!") + pass + # Separate the BOS if has_bos_token: partial_system = partial_system.replace(tokenizer.bos_token, "", 1) @@ -1015,10 +1233,14 @@ def process(part, which, content = "message['content']"): "{{ " + partial_system + " }}"\ "{% set loop_messages = messages[1:] %}" if default_system_message is not None: + full_system = system_part.replace("{SYSTEM}", default_system_message) partial_system += "{% else %}"\ - "{{ '" + system_part.replace("{SYSTEM}", default_system_message) + "' }}"\ + "{{ '" + full_system + "' }}"\ "{% set loop_messages = messages %}"\ "{% endif %}" + + # Add to modelfile + modelfile += '\nSYSTEM "' + full_system + '"' else: partial_system += "{% endif %}" pass @@ -1075,6 +1297,53 @@ def test_construct_chat_template(): pass +def apply_chat_template( \ + +dataset, +tokenizer = None, + +chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|> + +{SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|> + +{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +{OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|> + 
+{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|> + +{OUTPUT}<|eot_id|>""", + +default_system_message = \ + "Below are some instructions that describe some tasks. Write responses that appropriately complete each request.", + +extra_eos_tokens = None, + +): + """ + Creates a Ollama modelfile and a HF Jinja template from a custom + template. You must provide 2x examples of an input & output. + There is an optional system message as well. + + You must use {INPUT}, {OUTPUT} twice, and {SYSTEM} is optional. + """ + modelfile, jinja_template = construct_chat_template( + tokenizer = tokenizer, + chat_template = chat_template, + default_system_message = default_system_message, + extra_eos_tokens = extra_eos_tokens, + ) + def formatting_prompts_func(examples): + convos = examples["conversations"] + texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos] + return { "text" : texts, } + pass + tokenizer.chat_template = jinja_template + tokenizer._ollama_modelfile = modelfile + return dataset.map(formatting_prompts_func, batched = True,) +pass + + def create_ollama_modelfile(tokenizer, gguf_location): """ Creates an Ollama Modelfile. diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 022be109..9db7fcf2 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1430,15 +1430,30 @@ def get_peft_model( check_parameters = [ "r", "lora_alpha", "lora_dropout", "bias", "layers_to_transform", "layers_pattern", - "use_rslora", "modules_to_save", "init_lora_weights", + "use_rslora", "init_lora_weights", ] check_all = True for param in check_parameters: check_all = check_all and (peft_config[param] == eval(param)) pass + + # Check save_modules + old_target_modules = list(peft_config["target_modules"]) + modules_to_save = peft_config["modules_to_save"] + if modules_to_save is None: modules_to_save = {} + modules_to_save = list(modules_to_save) + old_target_modules += modules_to_save + + # Combine all + new_target_modules = list(target_modules) + \ + list(modules_to_save if modules_to_save is not None else []) + + # Now check! + new_target_modules = set(new_target_modules) check_all = check_all and ( - len(set(peft_config["target_modules"]) ^ set(target_modules)) == 0 + len(set(old_target_modules) ^ new_target_modules) == 0 ) + check_all = check_all and ( (loftq_config == {} or loftq_config is None) and \ (peft_config["loftq_config"] == {} or peft_config["loftq_config"] is None) @@ -1449,6 +1464,35 @@ def get_peft_model( logger.warning( "Unsloth: Already have LoRA adapters! We shall skip this step." ) + + # Offload! + # [TODO] First offload lm_head and embed_tokens to CPU (should be disk!!) + if "embed_tokens" in new_target_modules: + print("Unsloth: Casting embed_tokens to float32") + + model.model.model.embed_tokens.modules_to_save.default\ + .to(device = device, dtype = torch.float32, non_blocking = True) + model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) + + # [TODO] Move old embed_tokens to CPU - should be disk! 
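+            # Keeping the frozen original_module weights on the CPU means that only the
+            # trainable float32 copy in modules_to_save has to occupy GPU memory.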
+ model.model.model.embed_tokens.original_module\ + .to(device = "cpu", non_blocking = True) + model.model.model.embed_tokens.original_module.requires_grad_(False) + pass + + if "lm_head" in new_target_modules: + print("Unsloth: Casting lm_head to float32") + + model.model.lm_head.modules_to_save.default\ + .to(device = device, dtype = torch.float32, non_blocking = True) + model.model.lm_head.modules_to_save.default.requires_grad_(True) + + # [TODO] Move old lm_head to CPU - should be disk! + model.model.lm_head.original_module\ + .to(device = "cpu", non_blocking = True) + model.model.lm_head.original_module.requires_grad_(False) + pass + return model else: raise TypeError( @@ -1669,7 +1713,7 @@ def get_peft_model( print("Unsloth: Casting embed_tokens to float32") assert(hasattr(model.model.model.embed_tokens, "modules_to_save")) model.model.model.embed_tokens.modules_to_save.default\ - .to(device = input_embeddings_device, dtype = torch.float32, non_blocking = True) + .to(device = device, dtype = torch.float32, non_blocking = True) model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) pass @@ -1677,7 +1721,7 @@ def get_peft_model( print("Unsloth: Casting lm_head to float32") assert(hasattr(model.model.lm_head, "modules_to_save")) model.model.lm_head.modules_to_save.default\ - .to(device = output_embeddings_device, dtype = torch.float32, non_blocking = True) + .to(device = device, dtype = torch.float32, non_blocking = True) model.model.lm_head.modules_to_save.default.requires_grad_(True) pass diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 57624e6d..fe2dc06c 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -735,7 +735,8 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps = 1e-16): raise ValueError( 'Unsloth: Untrained tokens found, but embed_tokens & lm_head not trainable, causing NaNs. '\ 'Restart then add `embed_tokens` & `lm_head` to '\ - '`FastLanguageModel.get_peft_model(target_modules = [..., "embed_tokens", "lm_head",])`', + '`FastLanguageModel.get_peft_model(target_modules = [..., "embed_tokens", "lm_head",]). `'\ + 'Are you using the `base` model? Instead, use the `instruct` version to silence this warning.', ) pass
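For reference, the new pieces in this patch compose roughly as follows. This is a minimal usage sketch rather than part of the patch: the dataset, save paths, and GGUF filename are placeholders, and the LoRA/training setup is elided (see unsloth-cli.py above for a complete loop).

```python
from datasets import load_dataset
from unsloth import FastLanguageModel
from unsloth.chat_templates import (
    to_sharegpt, standardize_sharegpt, apply_chat_template, create_ollama_modelfile,
)

# Load a 4-bit base model and its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/llama-3-8b", max_seq_length = 2048, load_in_4bit = True,
)

# Merge Alpaca-style columns into single-turn ShareGPT conversations.
# [[...]] marks text that is optional when the enclosed column is empty.
dataset = load_dataset("yahma/alpaca-cleaned", split = "train")
dataset = to_sharegpt(
    dataset,
    merged_prompt = "{instruction}[[\nYour input is:\n{input}]]",
    output_column_name = "output",
    conversation_extension = 3,  # randomly pack a few rows into one longer conversation
)
dataset = standardize_sharegpt(dataset)

# Build the HF Jinja chat template plus the Ollama Modelfile (defaults to the Llama-3 style
# template above) and add a "text" column ready for SFTTrainer.
dataset = apply_chat_template(dataset, tokenizer = tokenizer)

# ... add LoRA adapters via FastLanguageModel.get_peft_model and train with SFTTrainer ...

# Export to GGUF, then derive the Ollama Modelfile for the exported file.
model.save_pretrained_gguf("model", tokenizer, quantization_method = "q8_0")
modelfile = create_ollama_modelfile(tokenizer, "model/unsloth.Q8_0.gguf")  # GGUF path is a placeholder
```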