diff --git a/README.md b/README.md
index 24a993792..46a661dff 100644
--- a/README.md
+++ b/README.md
@@ -494,4 +494,6 @@ If you come across a `FileNotFoundError`, it is likely due to an installation is
  - Add missing LR number of cycles and LR power to Dreambooth and TI scripts
  - Fix issue with conv_block_dims and conv_block_alphas
  - Fix 0 noise offset issue
- - Implement Stop training button on LoRA
\ No newline at end of file
+ - Implement Stop training button on LoRA
+ - Add support for extracting LoRA from SDXL finetuned models
+ - Add support for the PagedAdamW8bit and PagedLion8bit optimizers. These require a newer version of bitsandbytes, so success may vary from system to system. I had to uninstall all my NVIDIA drivers and other CUDA toolkit installs, delete all CUDA environment variable references, and reinstall CUDA Toolkit v11.8.0 to get things working... so not super easy.
\ No newline at end of file
diff --git a/examples/finetune_latent.ps1 b/examples/finetune_latent.ps1
new file mode 100644
index 000000000..87ea35e4c
--- /dev/null
+++ b/examples/finetune_latent.ps1
@@ -0,0 +1,11 @@
+# Command 1: merge_captions_to_metadata.py
+$captionExtension = "--caption_extension=.txt"
+$sourceDir1 = "d:\test\1_1960-1969"
+$targetFile1 = "d:\test\1_1960-1969/meta_cap.json"
+
+# Command 2: prepare_buckets_latents.py
+$targetLatentFile = "d:\test\1_1960-1969/meta_lat.json"
+$modelFile = "E:\models\sdxl\sd_xl_base_0.9.safetensors"
+
+./venv/Scripts/python.exe finetune/merge_captions_to_metadata.py $captionExtension $sourceDir1 $targetFile1 --full_path
+./venv/Scripts/python.exe finetune/prepare_buckets_latents.py $sourceDir1 $targetFile1 $targetLatentFile $modelFile --batch_size=4 --max_resolution=1024,1024 --min_bucket_reso=64 --max_bucket_reso=2048 --mixed_precision=bf16 --full_path
diff --git a/finetune_gui.py b/finetune_gui.py
index 0827c8249..9d650f2a2 100644
--- a/finetune_gui.py
+++ b/finetune_gui.py
@@ -89,6 +89,7 @@ def save_configuration(
     caption_extension,
     # use_8bit_adam,
     xformers,
+    sdpa,
     clip_skip,
     save_state,
     resume,
@@ -209,6 +210,7 @@ def open_configuration(
     caption_extension,
     # use_8bit_adam,
     xformers,
+    sdpa,
     clip_skip,
     save_state,
     resume,
@@ -326,6 +328,7 @@ def train_model(
     caption_extension,
     # use_8bit_adam,
     xformers,
+    sdpa,
     clip_skip,
     save_state,
     resume,
@@ -575,6 +578,7 @@ def train_model(
         gradient_checkpointing=gradient_checkpointing,
         full_fp16=full_fp16,
         xformers=xformers,
+        sdpa=sdpa,
         # use_8bit_adam=use_8bit_adam,
         keep_tokens=keep_tokens,
         persistent_data_loader_workers=persistent_data_loader_workers,
@@ -866,6 +870,7 @@ def finetune_tab(headless=False):
        source_model.save_model_as,
        basic_training.caption_extension,
        advanced_training.xformers,
+       advanced_training.sdpa,
        advanced_training.clip_skip,
        advanced_training.save_state,
        advanced_training.resume,
diff --git a/library/class_advanced_training.py b/library/class_advanced_training.py
index 6684d7d6d..cc5adba7c 100644
--- a/library/class_advanced_training.py
+++ b/library/class_advanced_training.py
@@ -101,13 +101,14 @@ def noise_offset_type_change(noise_offset_type):
             # use_8bit_adam = gr.Checkbox(
             #     label='Use 8bit adam', value=False, visible=False
             # )
-            self.xformers = gr.Checkbox(label='Use xformers', value=True)
+            self.xformers = gr.Checkbox(label='Use xformers', value=True, info='Use xformers for CrossAttention')
             self.color_aug = gr.Checkbox(label='Color augmentation', value=False)
             self.flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
             self.min_snr_gamma = gr.Slider(
                 label='Min SNR gamma', value=0, minimum=0, maximum=20, step=1
             )
         with gr.Row():
+            self.sdpa = gr.Checkbox(label='Use sdpa', value=False, info='Use sdpa for CrossAttention')
             self.bucket_no_upscale = gr.Checkbox(
                 label="Don't upscale bucket resolution", value=True
             )
diff --git a/library/common_gui.py b/library/common_gui.py
index 32ad66c76..3881e9096 100644
--- a/library/common_gui.py
+++ b/library/common_gui.py
@@ -735,6 +735,10 @@ def run_cmd_advanced_training(**kwargs):
     xformers = kwargs.get('xformers')
     if xformers:
         run_cmd += ' --xformers'
+
+    sdpa = kwargs.get('sdpa')
+    if sdpa:
+        run_cmd += ' --sdpa'
 
     persistent_data_loader_workers = kwargs.get('persistent_data_loader_workers')
     if persistent_data_loader_workers:
@@ -856,3 +860,18 @@ def check_duplicate_filenames(folder_path, image_extension = ['.gif', '.png', '.
                 print(f"Current file: {full_path}")
             else:
                 filenames[filename] = full_path
+
+def is_file_writable(file_path):
+    if not os.path.exists(file_path):
+        # print(f"File '{file_path}' does not exist.")
+        return True
+
+    try:
+        log.warning(f"File '{file_path}' already exists... it will be overwritten...")
+        # Check if the file can be opened in append mode (on Windows this fails if another process has it locked)
+        with open(file_path, 'a'):
+            pass
+        return True
+    except IOError:
+        log.warning(f"File '{file_path}' can't be written to...")
+        return False
\ No newline at end of file
diff --git a/library/extract_lora_gui.py b/library/extract_lora_gui.py
index 45ddb5f6d..1ea48b60d 100644
--- a/library/extract_lora_gui.py
+++ b/library/extract_lora_gui.py
@@ -4,8 +4,8 @@ import os
 from .common_gui import (
     get_saveasfilename_path,
-    get_any_file_path,
     get_file_path,
+    is_file_writable,
 )
 from library.custom_logging import setup_logging
@@ -27,25 +27,31 @@ def extract_lora(
     save_precision,
     dim,
     v2,
+    sdxl,
     conv_dim,
+    clamp_quantile,
+    min_diff,
     device,
 ):
     # Check for caption_text_input
     if model_tuned == '':
-        msgbox('Invalid finetuned model file')
+        log.info('Invalid finetuned model file')
         return
 
     if model_org == '':
-        msgbox('Invalid base model file')
+        log.info('Invalid base model file')
         return
 
     # Check if source model exist
     if not os.path.isfile(model_tuned):
-        msgbox('The provided finetuned model is not a file')
+        log.info('The provided finetuned model is not a file')
         return
 
     if not os.path.isfile(model_org):
-        msgbox('The provided base model is not a file')
+        log.info('The provided base model is not a file')
+        return
+
+    if not is_file_writable(save_to):
         return
 
     run_cmd = (
@@ -61,6 +67,10 @@ def extract_lora(
         run_cmd += f' --conv_dim {conv_dim}'
     if v2:
         run_cmd += f' --v2'
+    if sdxl:
+        run_cmd += f' --sdxl'
+    run_cmd += f' --clamp_quantile {clamp_quantile}'
+    run_cmd += f' --min_diff {min_diff}'
 
     log.info(run_cmd)
@@ -160,7 +170,19 @@ def gradio_extract_lora_tab(headless=False):
                 step=1,
                 interactive=True,
             )
+            clamp_quantile = gr.Number(
+                label='Clamp Quantile',
+                value=1,
+                interactive=True,
+            )
+            min_diff = gr.Number(
+                label='Minimum difference',
+                value=0.01,
+                interactive=True,
+            )
+        with gr.Row():
             v2 = gr.Checkbox(label='v2', value=False, interactive=True)
+            sdxl = gr.Checkbox(label='SDXL', value=False, interactive=True)
             device = gr.Dropdown(
                 label='Device',
                 choices=[
@@ -182,7 +204,10 @@ def gradio_extract_lora_tab(headless=False):
             save_precision,
             dim,
             v2,
+            sdxl,
             conv_dim,
+            clamp_quantile,
+            min_diff,
             device,
         ],
         show_progress=False,
diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py
index 0bc1afe0a..fb0ba7cc6 100644
--- a/networks/extract_lora_from_models.py
+++ b/networks/extract_lora_from_models.py
@@ -12,10 +12,8 @@
 import library.sdxl_model_util as sdxl_model_util
 import lora
 
-
-CLAMP_QUANTILE = 0.99
-MIN_DIFF = 1e-4
-
+# CLAMP_QUANTILE = 1
+# MIN_DIFF = 1e-2
 
 def save_to_file(file_name, model, state_dict, dtype):
     if dtype is not None:
@@ -91,9 +89,9 @@ def str_to_dtype(p):
             diff = module_t.weight - module_o.weight
 
             # Text Encoder might be same
-            if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
+            if not text_encoder_different and torch.max(torch.abs(diff)) > args.min_diff:
                 text_encoder_different = True
-                print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
+                print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {args.min_diff}")
 
             diff = diff.float()
             diffs[lora_name] = diff
@@ -149,7 +147,7 @@ def str_to_dtype(p):
         Vh = Vh[:rank, :]
 
         dist = torch.cat([U.flatten(), Vh.flatten()])
-        hi_val = torch.quantile(dist, CLAMP_QUANTILE)
+        hi_val = torch.quantile(dist, args.clamp_quantile)
         low_val = -hi_val
 
         U = U.clamp(low_val, hi_val)
@@ -243,6 +241,18 @@ def setup_parser() -> argparse.ArgumentParser:
         help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
     )
     parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
+    parser.add_argument(
+        "--clamp_quantile",
+        type=float,
+        default=1,
+        help="Quantile clamping value, float, (0-1). Default = 1",
+    )
+    parser.add_argument(
+        "--min_diff",
+        type=float,
+        default=0.01,
+        help="Minimum difference between the finetuned model and the base model to consider them different enough to extract, float, (0-1). Default = 0.01",
+    )
 
     return parser
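
Note: below is a minimal usage sketch of the new extraction flags, written in the style of the existing examples/*.ps1 scripts. The paths are placeholders, and --model_org/--model_tuned/--save_to/--dim/--device are assumed to be the script's pre-existing arguments; only --sdxl, --clamp_quantile, and --min_diff are added by this change.

    # Extract a LoRA from an SDXL finetune (hypothetical paths)
    $baseModel  = "E:\models\sdxl\sd_xl_base_0.9.safetensors"
    $tunedModel = "E:\models\sdxl\my_sdxl_finetune.safetensors"
    $loraFile   = "E:\models\lora\extracted_lora.safetensors"

    ./venv/Scripts/python.exe networks/extract_lora_from_models.py `
        --model_org $baseModel `
        --model_tuned $tunedModel `
        --save_to $loraFile `
        --dim 64 `
        --device cuda `
        --sdxl `
        --clamp_quantile 1 `
        --min_diff 0.01

With --clamp_quantile 1 the quantile clamp is effectively a no-op (it clamps to the distribution's maximum); lower it toward the old 0.99 default to suppress outlier weights in the extracted LoRA.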