diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index 1b122fc8..49b8ba39 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -374,7 +374,8 @@ def prepare_n_gradient_checkpoints(

 # Unsloth only works on NVIDIA GPUs for now
 device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + ","
-device = f"cuda:{device_ids[:device_ids.find(',')]}"
+device = device_ids[:device_ids.find(',')] # Unsloth only works on NVIDIA GPUs for now
+device = f"cuda:{device if device.isdigit() else '0'}"

 class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
     """
diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py
index 98502836..0cc047d2 100644
--- a/unsloth/models/gemma.py
+++ b/unsloth/models/gemma.py
@@ -38,11 +38,9 @@ GemmaFlashAttention2 = GemmaAttention
 pass

-# Unsloth currently only works on one GPU
 import os
 device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + ","
-device = f"cuda:{device_ids[:device_ids.find(',')]}"
-# Please obtain a commercial license
+device = f"cuda:{device_ids[:device_ids.find(',')]}" # Unsloth only works on NVIDIA GPUs for now


 torch_nn_functional_gelu = torch.nn.functional.gelu
 def fast_geglu_inference(self, X):
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 9327b1bb..f2f79de8 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -76,7 +76,8 @@ def original_apply_o(self, X):
 import os
 # Unsloth only works on NVIDIA GPUs for now
 device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + ","
-device = f"cuda:{device_ids[:device_ids.find(',')]}"
+device = device_ids[:device_ids.find(',')] # Unsloth only works on NVIDIA GPUs for now
+device = f"cuda:{device if device.isdigit() else '0'}"

 from math import sqrt as math_sqrt
 KV_CACHE_INCREMENT = 256 # KV Cache update size
@@ -846,7 +847,8 @@ def _CausalLM_fast_forward(
         shift_logits = logits
         if not hasattr(self, "extra_ignored_labels"):
             device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + ","
-            device = f"cuda:{device_ids[:device_ids.find(',')]}" # Unsloth only works on NVIDIA GPUs for now
+            device = device_ids[:device_ids.find(',')] # Unsloth only works on NVIDIA GPUs for now
+            device = f"cuda:{device if device.isdigit() else '0'}"
             # Fixes https://github.com/unslothai/unsloth/issues/10
             self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = device)
         pass
@@ -1828,7 +1830,8 @@ def patch_peft_model(
     # Fixes https://github.com/unslothai/unsloth/issues/10
     max_seq_length = model.max_seq_length
     device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + ","
-    device = f"cuda:{device_ids[:device_ids.find(',')]}" # Unsloth only works on NVIDIA GPUs for now
+    device = device_ids[:device_ids.find(',')] # Unsloth only works on NVIDIA GPUs for now
+    device = f"cuda:{device if device.isdigit() else '0'}"
     extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = device)
     model.model.extra_ignored_labels = extra_ignored_labels
     internal_model = model
diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py
index e147f215..832189be 100644
--- a/unsloth/models/mistral.py
+++ b/unsloth/models/mistral.py
@@ -240,7 +240,8 @@ def MistralForCausalLM_fast_forward(
         shift_logits = logits
         if not hasattr(self, "extra_ignored_labels"):
             device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + ","
-            device = f"cuda:{device_ids[:device_ids.find(',')]}" # Unsloth only works on NVIDIA GPUs for now
+            device = device_ids[:device_ids.find(',')] # Unsloth only works on NVIDIA GPUs for now
+            device = f"cuda:{device if device.isdigit() else '0'}"
             # Fixes https://github.com/unslothai/unsloth/issues/10
             self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = device)
         pass
diff --git a/unsloth/save.py b/unsloth/save.py
index cae59cae..940feb40 100644
--- a/unsloth/save.py
+++ b/unsloth/save.py
@@ -418,6 +418,11 @@ def unsloth_save_model(

     print("Unsloth: Saving model...", end = "")
     if save_method != "lora": print(" This might take 10 minutes for Llama-7b...", end = "")
+    # [TODO] Is this correct?
+    if save_method == "lora":
+        save_pretrained_settings["selected_adapters"] = None
+    pass
+
     model.save_pretrained(**save_pretrained_settings)

     if push_to_hub and hasattr(model, "config"):
@@ -649,8 +654,9 @@ def unsloth_save_model(
     model.config = new_config

     # Save!
-
-    save_pretrained_settings["selected_adapters"] = None
+    # [TODO] --> is this correct?
+    # save_pretrained_settings["selected_adapters"] = None
+
     # Check if pushing to an organization
     if save_pretrained_settings["push_to_hub"] and (username != actual_username):
         print(f"Unsloth: Saving to organization with address {new_save_directory}")
@@ -834,7 +840,7 @@ def save_to_gguf(
     model_dtype           : str,
     is_sentencepiece      : bool = False,
     model_directory       : str = "unsloth_finetuned_model",
-    quantization_method   : str = "fast_quantized",
+    quantization_method         = "fast_quantized", # Can be a list of options! ["q4_k_m", "q8_0", "q5_k_m"]
     first_conversion      : str = None,
     _run_installer        = None, # Non blocking install of llama.cpp
 ):
@@ -846,6 +852,10 @@ def save_to_gguf(
     assert(model_dtype == "float16" or model_dtype == "bfloat16")
     model_dtype = "f16" if model_dtype == "float16" else "bf16"

+    # Convert quantization_method to list (wrap a single string instead of splitting it into characters)
+    quantization_method = \
+        quantization_method if type(quantization_method) is list else [quantization_method]
+
     # Check if bfloat16 is supported
     if model_dtype == "bf16" and not torch.cuda.is_bf16_supported():
         logger.warning(
@@ -860,8 +870,11 @@ def save_to_gguf(
         first_conversion = model_dtype
     pass

-    if quantization_method.startswith("iq2"):
-        raise RuntimeError("Unsloth: Currently iq2 type quantizations aren't supported yet - sorry!")
+    # Check I quants
+    for quant_method in quantization_method:
+        if quant_method.startswith("iq2"):
+            raise RuntimeError("Unsloth: Currently iq2 type quantizations aren't supported yet - sorry!")
+    pass

     # Careful convert.py is only for Llama / Mistral based archs
     use_fast_convert = False
@@ -871,25 +884,32 @@ def save_to_gguf(
     pass
     logger.warning_once(f"Unsloth: Converting {model_type} model. Can use fast conversion = {use_fast_convert}.")

-    if   quantization_method == "not_quantized":  quantization_method = model_dtype
-    elif quantization_method == "fast_quantized": quantization_method = "q8_0"
-    elif quantization_method == "quantized":      quantization_method = "q4_k_m"
-    elif quantization_method is None:             quantization_method = "q8_0"
-    pass
+    # Map quant methods
+    new_quantization_method = []
+    for quant_method in quantization_method:
+        if   quant_method == "not_quantized":  quant_method = model_dtype
+        elif quant_method == "fast_quantized": quant_method = "q8_0"
+        elif quant_method == "quantized":      quant_method = "q4_k_m"
+        elif quant_method is None:             quant_method = "q8_0"
+
+        # Check if wrong method
+        if quant_method not in ALLOWED_QUANTS.keys():
+            error = f"Unsloth: Quant method = [{quant_method}] not supported. Choose from below:\n"
+            for key, value in ALLOWED_QUANTS.items():
+                error += f"[{key}] => {value}\n"
+            raise RuntimeError(error)
+        pass

-    if quantization_method not in ALLOWED_QUANTS.keys():
-        error = f"Unsloth: Quant method = [{quantization_method}] not supported. Choose from below:\n"
-        for key, value in ALLOWED_QUANTS.items():
-            error += f"[{key}] => {value}\n"
-        raise RuntimeError(error)
+        new_quantization_method.append(quant_method)
     pass
+    quantization_method = new_quantization_method

     print_info = \
         f"==((====))==  Unsloth: Conversion from QLoRA to GGUF information\n"\
         f"   \\\   /|    [0] Installing llama.cpp will take 3 minutes.\n"\
         f"O^O/ \_/ \\    [1] Converting HF to GUUF 16bits will take 3 minutes.\n"\
-        f"\        /    [2] Converting GGUF 16bits to {quantization_method} will take 20 minutes.\n"\
-        f' "-____-"     In total, you will have to wait around 26 minutes.\n'
+        f"\        /    [2] Converting GGUF 16bits to {quantization_method} will take 10 minutes each.\n"\
+        f' "-____-"     In total, you will have to wait at least 16 minutes.\n'
     print(print_info)

     # Check first_conversion format
@@ -928,24 +948,37 @@ def save_to_gguf(
         install_llama_cpp_old(-10)
     pass

-    if   quantization_method == "f32":  first_conversion = "f32"
-    elif quantization_method == "f16":  first_conversion = "f16"
-    elif quantization_method == "bf16": first_conversion = "bf16"
-    elif quantization_method == "q8_0": first_conversion = "q8_0"
-    else:
-        # Quantized models must have f16 as the default argument
-        if   first_conversion == "f32"  : pass
-        elif first_conversion == "f16"  : pass
-        elif first_conversion == "bf16" : pass
-        elif first_conversion == "q8_0":
-            logger.warning_once(
-                "Unsloth: Using q8_0 for the `first_conversion` will lose a bit of accuracy, "\
-                "but saves disk space!"
-            )
-            # first_conversion = "f16"
+    # Determine maximum first_conversion state
+    if   first_conversion == "f32"  : strength = 3
+    elif first_conversion == "f16"  : strength = 2
+    elif first_conversion == "bf16" : strength = 1
+    elif first_conversion == "q8_0" : strength = 0
+
+    for quant_method in quantization_method:
+        if   quant_method == "f32":  strength = max(strength, 3)
+        elif quant_method == "f16":  strength = max(strength, 2)
+        elif quant_method == "bf16": strength = max(strength, 1)
+        elif quant_method == "q8_0": strength = max(strength, 0)
+        else:
+            # Quantized models must have f16 as the default argument
+            if   first_conversion == "f32"  : pass
+            elif first_conversion == "f16"  : pass
+            elif first_conversion == "bf16" : pass
+            elif first_conversion == "q8_0":
+                logger.warning_once(
+                    "Unsloth: Using q8_0 for the `first_conversion` will lose a bit of accuracy, "\
+                    "but saves disk space!"
+                )
+                # first_conversion = "f16"
+            pass
         pass
     pass

+    if   strength >= 3: first_conversion = "f32"
+    elif strength >= 2: first_conversion = "f16"
+    elif strength >= 1: first_conversion = "bf16"
+    else:               first_conversion = "q8_0"
+
     # Non llama/mistral needs can only use f32 or f16
     if not use_fast_convert and \
         (first_conversion != "f16" or first_conversion != "bf16" or first_conversion != "f32"):
@@ -1033,52 +1066,58 @@ def save_to_gguf(
     pass

     print(f"Unsloth: Conversion completed! Output location: {final_location}")

-    if quantization_method != first_conversion:
-        old_location = final_location
-        print(f"Unsloth: [2] Converting GGUF 16bit into {quantization_method}. This will take 20 minutes...")
-        final_location = f"./{model_directory}-unsloth.{quantization_method.upper()}.gguf"
+    full_precision_location = final_location

-        command = f"./{quantize_location} {old_location} "\
-                  f"{final_location} {quantization_method} {n_cpus}"
-
-        # quantize uses stderr
-        with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
-            for line in sp.stdout:
-                line = line.decode("utf-8", errors = "replace")
-                if "undefined reference" in line:
-                    raise RuntimeError("Failed compiling llama.cpp. Please report this ASAP!")
-                print(line, flush = True, end = "")
-            if sp.returncode is not None and sp.returncode != 0:
-                raise subprocess.CalledProcessError(sp.returncode, sp.args)
-        pass
+    all_saved_locations = []
+    # Convert each type!
+    for quant_method in quantization_method:
+        if quant_method != first_conversion:
+            print(f"Unsloth: [2] Converting GGUF 16bit into {quant_method}. This will take 20 minutes...")
+            final_location = f"./{model_directory}-unsloth.{quant_method.upper()}.gguf"

-        # Check if quantization succeeded!
-        if not os.path.isfile(final_location):
-            if IS_KAGGLE_ENVIRONMENT:
-                raise RuntimeError(
-                    f"Unsloth: Quantization failed for {final_location}\n"\
-                    "You are in a Kaggle environment, which might be the reason this is failing.\n"\
-                    "Kaggle only provides 20GB of disk space. Merging to 16bit for 7b models use 16GB of space.\n"\
-                    "This means using `model.{save_pretrained/push_to_hub}_merged` works, but\n"\
-                    "`model.{save_pretrained/push_to_hub}_gguf will use too much disk space.\n"\
-                    "I suggest you to save the 16bit model first, then use manual llama.cpp conversion."
-                )
-            else:
-                raise RuntimeError(
-                    "Unsloth: Quantization failed! You might have to compile llama.cpp yourself, then run this again.\n"\
-                    "You do not need to close this Python program. Run the following commands in a new terminal:\n"\
-                    "You must run this in the same folder as you're saving your model.\n"\
-                    "git clone --recursive https://github.com/ggerganov/llama.cpp\n"\
-                    "cd llama.cpp && make clean && make all -j\n"\
-                    "Once that's done, redo the quantization."
-                )
+            command = f"./{quantize_location} {full_precision_location} "\
+                      f"{final_location} {quant_method} {n_cpus}"
+
+            # quantize uses stderr
+            with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
+                for line in sp.stdout:
+                    line = line.decode("utf-8", errors = "replace")
+                    if "undefined reference" in line:
+                        raise RuntimeError("Failed compiling llama.cpp. Please report this ASAP!")
+                    print(line, flush = True, end = "")
+                if sp.returncode is not None and sp.returncode != 0:
+                    raise subprocess.CalledProcessError(sp.returncode, sp.args)
             pass
-        pass
-        print(f"Unsloth: Conversion completed! Output location: {final_location}")
+            # Check if quantization succeeded!
+            if not os.path.isfile(final_location):
+                if IS_KAGGLE_ENVIRONMENT:
+                    raise RuntimeError(
+                        f"Unsloth: Quantization failed for {final_location}\n"\
+                        "You are in a Kaggle environment, which might be the reason this is failing.\n"\
+                        "Kaggle only provides 20GB of disk space. Merging to 16bit for 7b models use 16GB of space.\n"\
+                        "This means using `model.{save_pretrained/push_to_hub}_merged` works, but\n"\
+                        "`model.{save_pretrained/push_to_hub}_gguf will use too much disk space.\n"\
+                        "I suggest you to save the 16bit model first, then use manual llama.cpp conversion."
+                    )
+                else:
+                    raise RuntimeError(
+                        "Unsloth: Quantization failed! You might have to compile llama.cpp yourself, then run this again.\n"\
+                        "You do not need to close this Python program. Run the following commands in a new terminal:\n"\
+                        "You must run this in the same folder as you're saving your model.\n"\
+                        "git clone --recursive https://github.com/ggerganov/llama.cpp\n"\
+                        "cd llama.cpp && make clean && make all -j\n"\
+                        "Once that's done, redo the quantization."
+                    )
+                pass
+            pass
+
+            print(f"Unsloth: Conversion completed! Output location: {final_location}")
+            all_saved_locations.append(final_location)
+        pass
     pass

-    return final_location
+    return all_saved_locations
 pass
@@ -1453,7 +1492,7 @@ def unsloth_save_pretrained_gguf(
     is_sentencepiece_model = check_if_sentencepiece_model(self)

     # Save to GGUF
-    file_location = save_to_gguf(model_type, model_dtype, is_sentencepiece_model,
+    all_file_locations = save_to_gguf(model_type, model_dtype, is_sentencepiece_model,
         new_save_directory, quantization_method, first_conversion, makefile,
     )

@@ -1466,14 +1505,17 @@ def unsloth_save_pretrained_gguf(

     if push_to_hub:
         print("Unsloth: Uploading GGUF to Huggingface Hub...")
-        username = upload_to_huggingface(
-            self, save_directory, token,
-            "GGUF converted", "gguf", file_location, old_username, private,
-        )
-        link = f"{username}/{new_save_directory.lstrip('/.')}" \
-            if username not in new_save_directory else \
-            new_save_directory.lstrip('/.')
-        print(f"Saved GGUF to https://huggingface.co/{link}")
+
+        for file_location in all_file_locations:
+            username = upload_to_huggingface(
+                self, save_directory, token,
+                "GGUF converted", "gguf", file_location, old_username, private,
+            )
+            link = f"{username}/{new_save_directory.lstrip('/.')}" \
+                if username not in new_save_directory else \
+                new_save_directory.lstrip('/.')
+            print(f"Saved GGUF to https://huggingface.co/{link}")
+        pass
     pass
 pass

@@ -1604,20 +1646,22 @@ def unsloth_push_to_hub_gguf(
     is_sentencepiece_model = check_if_sentencepiece_model(self)

     # Save to GGUF
-    file_location = save_to_gguf(model_type, model_dtype, is_sentencepiece_model,
+    all_file_locations = save_to_gguf(model_type, model_dtype, is_sentencepiece_model,
        new_save_directory, quantization_method, first_conversion, makefile,
     )

-    print("Unsloth: Uploading GGUF to Huggingface Hub...")
-    username = upload_to_huggingface(
-        self, repo_id, token,
-        "GGUF converted", "gguf", file_location, old_username, private,
-    )
-    link = f"{username}/{new_save_directory.lstrip('/.')}" \
-        if username not in new_save_directory else \
-        new_save_directory.lstrip('/.')
+    for file_location in all_file_locations:
+        print("Unsloth: Uploading GGUF to Huggingface Hub...")
+        username = upload_to_huggingface(
+            self, repo_id, token,
+            "GGUF converted", "gguf", file_location, old_username, private,
+        )
+        link = f"{username}/{new_save_directory.lstrip('/.')}" \
+            if username not in new_save_directory else \
+            new_save_directory.lstrip('/.')

-    print(f"Saved GGUF to https://huggingface.co/{link}")
+        print(f"Saved GGUF to https://huggingface.co/{link}")
+    pass

     if fix_bos_token:
         logger.warning(
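
A minimal usage sketch of what this patch enables: quantization_method can now be a single string or a list, every entry is quantized from one shared 16-bit GGUF conversion, and save_to_gguf returns the list of written file locations. The checkpoint name, output directory, and hyperparameters below are illustrative placeholders and not taken from this patch; only the quantization_method keyword and the GGUF save wrappers come from the code above.

# Sketch only: exercises the list-valued quantization_method added in this diff.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "unsloth/llama-2-7b-bnb-4bit",  # placeholder checkpoint
    max_seq_length = 2048,
    load_in_4bit   = True,
)

# ... fine-tune the model as usual ...

# Each requested quant reuses the same 16-bit conversion; one .gguf file is
# written (and, with push_to_hub_gguf, uploaded) per entry in the list.
model.save_pretrained_gguf(
    "unsloth_finetuned_model",
    tokenizer,
    quantization_method = ["q4_k_m", "q8_0", "q5_k_m"],
)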