diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 6d5664dec..998a53b38 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -1,916 +1,16 @@ -# v1: initial release -# v2: add open and save folder icons -# v3: Add new Utilities tab for Dreambooth folder preparation -# v3.1: Adding captionning of images to utilities - +import argparse import gradio as gr -import json -import math import os -import subprocess -import pathlib -import argparse -from datetime import datetime -from library.common_gui import ( - get_file_path, - get_saveasfile_path, - color_aug_changed, - save_inference_file, - run_cmd_advanced_training, - update_my_data, - check_if_model_exist, - output_message, - verify_image_folder_pattern, - SaveConfigFile, - save_to_file, -) -from library.class_configuration_file import ConfigurationFile -from library.class_source_model import SourceModel -from library.class_basic_training import BasicTraining -from library.class_advanced_training import AdvancedTraining -from library.class_folders import Folders -from library.class_command_executor import CommandExecutor -from library.class_sdxl_parameters import SDXLParameters -from library.tensorboard_gui import ( - gradio_tensorboard, - start_tensorboard, - stop_tensorboard, -) -from library.dreambooth_folder_creation_gui import ( - gradio_dreambooth_folder_creation_tab, -) -from library.dataset_balancing_gui import gradio_dataset_balancing_tab -from library.utilities import utilities_tab -from library.class_sample_images import SampleImages, run_cmd_sample - -from library.custom_logging import setup_logging -from library.localization_ext import add_javascript - -# Set up logging -log = setup_logging() - -# Setup command executor -executor = CommandExecutor() - - -def save_configuration( - save_as, - file_path, - pretrained_model_name_or_path, - v2, - v_parameterization, - sdxl, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - learning_rate_te, - learning_rate_te1, - learning_rate_te2, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - cache_latents_to_disk, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - full_bf16, - no_token_padding, - stop_text_encoder_training, - min_bucket_reso, - max_bucket_reso, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - vae, - num_processes, - num_machines, - multi_gpu, - gpu_ids, - output_name, - max_token_length, - max_train_epochs, - max_train_steps, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, - keep_tokens, - lr_scheduler_num_cycles, - lr_scheduler_power, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - v_pred_like_loss, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - lr_scheduler_args, - noise_offset_type, - noise_offset, - adaptive_noise_scale, - multires_noise_iterations, - multires_noise_discount, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, - weighted_captions, - save_every_n_steps, - save_last_n_steps, - save_last_n_steps_state, - use_wandb, - wandb_api_key, - scale_v_pred_loss_like_noise_pred, - min_timestep, - max_timestep, -): - # Get list of function 
parameters and values - parameters = list(locals().items()) - - original_file_path = file_path - - save_as_bool = True if save_as.get("label") == "True" else False - - if save_as_bool: - log.info("Save as...") - file_path = get_saveasfile_path(file_path) - else: - log.info("Save...") - if file_path == None or file_path == "": - file_path = get_saveasfile_path(file_path) - - if file_path == None or file_path == "": - return original_file_path # In case a file_path was provided and the user decide to cancel the open action - - # Extract the destination directory from the file path - destination_directory = os.path.dirname(file_path) - - # Create the destination directory if it doesn't exist - if not os.path.exists(destination_directory): - os.makedirs(destination_directory) - - SaveConfigFile( - parameters=parameters, - file_path=file_path, - exclusion=["file_path", "save_as"], - ) - - return file_path - - -def open_configuration( - ask_for_file, - file_path, - pretrained_model_name_or_path, - v2, - v_parameterization, - sdxl, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - learning_rate_te, - learning_rate_te1, - learning_rate_te2, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - cache_latents_to_disk, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - full_bf16, - no_token_padding, - stop_text_encoder_training, - min_bucket_reso, - max_bucket_reso, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - vae, - num_processes, - num_machines, - multi_gpu, - gpu_ids, - output_name, - max_token_length, - max_train_epochs, - max_train_steps, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, - keep_tokens, - lr_scheduler_num_cycles, - lr_scheduler_power, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - v_pred_like_loss, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - lr_scheduler_args, - noise_offset_type, - noise_offset, - adaptive_noise_scale, - multires_noise_iterations, - multires_noise_discount, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, - weighted_captions, - save_every_n_steps, - save_last_n_steps, - save_last_n_steps_state, - use_wandb, - wandb_api_key, - scale_v_pred_loss_like_noise_pred, - min_timestep, - max_timestep, -): - # Get list of function parameters and values - parameters = list(locals().items()) - - ask_for_file = True if ask_for_file.get("label") == "True" else False - - original_file_path = file_path - - if ask_for_file: - file_path = get_file_path(file_path) - - if not file_path == "" and not file_path == None: - # load variables from JSON file - with open(file_path, "r") as f: - my_data = json.load(f) - log.info("Loading config...") - # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True - my_data = update_my_data(my_data) - else: - file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action - my_data = {} - values = [file_path] - for key, value in parameters: - # Set the value in the dictionary to the corresponding value in `my_data`, or the 
default value if not found - if not key in ["ask_for_file", "file_path"]: - values.append(my_data.get(key, value)) - return tuple(values) +from kohya_gui.dreambooth_gui import dreambooth_tab +from kohya_gui.utilities import utilities_tab +from kohya_gui.custom_logging import setup_logging +from kohya_gui.localization_ext import add_javascript -def train_model( - headless, - print_only, - pretrained_model_name_or_path, - v2, - v_parameterization, - sdxl, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - learning_rate_te, - learning_rate_te1, - learning_rate_te2, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - cache_latents_to_disk, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - full_bf16, - no_token_padding, - stop_text_encoder_training_pct, - min_bucket_reso, - max_bucket_reso, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - vae, - num_processes, - num_machines, - multi_gpu, - gpu_ids, - output_name, - max_token_length, - max_train_epochs, - max_train_steps, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, # Keep this. Yes, it is unused here but required given the common list used - keep_tokens, - lr_scheduler_num_cycles, - lr_scheduler_power, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - v_pred_like_loss, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - lr_scheduler_args, - noise_offset_type, - noise_offset, - adaptive_noise_scale, - multires_noise_iterations, - multires_noise_discount, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, - weighted_captions, - save_every_n_steps, - save_last_n_steps, - save_last_n_steps_state, - use_wandb, - wandb_api_key, - scale_v_pred_loss_like_noise_pred, - min_timestep, - max_timestep, -): - # Get list of function parameters and values - parameters = list(locals().items()) - print_only_bool = True if print_only.get("label") == "True" else False - log.info(f"Start training Dreambooth...") - - headless_bool = True if headless.get("label") == "True" else False - - if pretrained_model_name_or_path == "": - output_message( - msg="Source model information is missing", headless=headless_bool - ) - return - - if train_data_dir == "": - output_message(msg="Image folder path is missing", headless=headless_bool) - return - - if not os.path.exists(train_data_dir): - output_message(msg="Image folder does not exist", headless=headless_bool) - return - - if not verify_image_folder_pattern(train_data_dir): - return - - if reg_data_dir != "": - if not os.path.exists(reg_data_dir): - output_message( - msg="Regularisation folder does not exist", - headless=headless_bool, - ) - return - - if not verify_image_folder_pattern(reg_data_dir): - return - - if output_dir == "": - output_message(msg="Output folder path is missing", headless=headless_bool) - return - - if check_if_model_exist( - output_name, output_dir, save_model_as, headless=headless_bool - ): - return - - # if sdxl: - # output_message( - # msg='Dreambooth training is not compatible with SDXL models yet..', - # headless=headless_bool, - # ) - # return - - # if optimizer == 
'Adafactor' and lr_warmup != '0': - # output_message( - # msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", - # title='Warning', - # headless=headless_bool, - # ) - # lr_warmup = '0' - - # Get a list of all subfolders in train_data_dir, excluding hidden folders - subfolders = [ - f - for f in os.listdir(train_data_dir) - if os.path.isdir(os.path.join(train_data_dir, f)) and not f.startswith(".") - ] - - # Check if subfolders are present. If not let the user know and return - if not subfolders: - log.info(f"No {subfolders} were found in train_data_dir can't train...") - return - - total_steps = 0 - - # Loop through each subfolder and extract the number of repeats - for folder in subfolders: - # Extract the number of repeats from the folder name - try: - repeats = int(folder.split("_")[0]) - except ValueError: - log.info( - f"Subfolder {folder} does not have a proper repeat value, please correct the name or remove it... can't train..." - ) - continue - - # Count the number of images in the folder - num_images = len( - [ - f - for f, lower_f in ( - (file, file.lower()) - for file in os.listdir(os.path.join(train_data_dir, folder)) - ) - if lower_f.endswith((".jpg", ".jpeg", ".png", ".webp")) - ] - ) - - if num_images == 0: - log.info(f"{folder} folder contain no images, skipping...") - else: - # Calculate the total number of steps for this folder - steps = repeats * num_images - total_steps += steps - - # Print the result - log.info(f"Folder {folder} : steps {steps}") - - if total_steps == 0: - log.info(f"No images were found in folder {train_data_dir}... please rectify!") - return - - # Print the result - # log.info(f"{total_steps} total steps") - - if reg_data_dir == "": - reg_factor = 1 - else: - log.info( - f"Regularisation images are used... Will double the number of steps required..." 
- ) - reg_factor = 2 - - if max_train_steps == "" or max_train_steps == "0": - # calculate max_train_steps - max_train_steps = int( - math.ceil( - float(total_steps) - / int(train_batch_size) - / int(gradient_accumulation_steps) - * int(epoch) - * int(reg_factor) - ) - ) - log.info( - f"max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}" - ) - - # calculate stop encoder training - if int(stop_text_encoder_training_pct) == -1: - stop_text_encoder_training = -1 - elif stop_text_encoder_training_pct == None: - stop_text_encoder_training = 0 - else: - stop_text_encoder_training = math.ceil( - float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) - ) - log.info(f"stop_text_encoder_training = {stop_text_encoder_training}") - - lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) - log.info(f"lr_warmup_steps = {lr_warmup_steps}") - - # run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db.py"' - run_cmd = "accelerate launch" - - run_cmd += run_cmd_advanced_training( - num_processes=num_processes, - num_machines=num_machines, - multi_gpu=multi_gpu, - gpu_ids=gpu_ids, - num_cpu_threads_per_process=num_cpu_threads_per_process, - ) - - if sdxl: - run_cmd += f' "./sdxl_train.py"' - else: - run_cmd += f' "./train_db.py"' - - run_cmd += run_cmd_advanced_training( - adaptive_noise_scale=adaptive_noise_scale, - additional_parameters=additional_parameters, - bucket_no_upscale=bucket_no_upscale, - bucket_reso_steps=bucket_reso_steps, - cache_latents=cache_latents, - cache_latents_to_disk=cache_latents_to_disk, - caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, - caption_dropout_rate=caption_dropout_rate, - caption_extension=caption_extension, - clip_skip=clip_skip, - color_aug=color_aug, - enable_bucket=enable_bucket, - epoch=epoch, - flip_aug=flip_aug, - full_bf16=full_bf16, - full_fp16=full_fp16, - gradient_accumulation_steps=gradient_accumulation_steps, - gradient_checkpointing=gradient_checkpointing, - keep_tokens=keep_tokens, - learning_rate=learning_rate, - learning_rate_te1=learning_rate_te1 if sdxl else None, - learning_rate_te2=learning_rate_te2 if sdxl else None, - learning_rate_te=learning_rate_te if not sdxl else None, - logging_dir=logging_dir, - lr_scheduler=lr_scheduler, - lr_scheduler_args=lr_scheduler_args, - lr_scheduler_num_cycles=lr_scheduler_num_cycles, - lr_scheduler_power=lr_scheduler_power, - lr_warmup_steps=lr_warmup_steps, - max_bucket_reso=max_bucket_reso, - max_data_loader_n_workers=max_data_loader_n_workers, - max_resolution=max_resolution, - max_timestep=max_timestep, - max_token_length=max_token_length, - max_train_epochs=max_train_epochs, - max_train_steps=max_train_steps, - mem_eff_attn=mem_eff_attn, - min_bucket_reso=min_bucket_reso, - min_snr_gamma=min_snr_gamma, - min_timestep=min_timestep, - mixed_precision=mixed_precision, - multires_noise_discount=multires_noise_discount, - multires_noise_iterations=multires_noise_iterations, - no_token_padding=no_token_padding, - noise_offset=noise_offset, - noise_offset_type=noise_offset_type, - optimizer=optimizer, - optimizer_args=optimizer_args, - output_dir=output_dir, - output_name=output_name, - persistent_data_loader_workers=persistent_data_loader_workers, - pretrained_model_name_or_path=pretrained_model_name_or_path, - prior_loss_weight=prior_loss_weight, - random_crop=random_crop, - reg_data_dir=reg_data_dir, - resume=resume, - 
save_every_n_epochs=save_every_n_epochs, - save_every_n_steps=save_every_n_steps, - save_last_n_steps=save_last_n_steps, - save_last_n_steps_state=save_last_n_steps_state, - save_model_as=save_model_as, - save_precision=save_precision, - save_state=save_state, - scale_v_pred_loss_like_noise_pred=scale_v_pred_loss_like_noise_pred, - seed=seed, - shuffle_caption=shuffle_caption, - stop_text_encoder_training=stop_text_encoder_training, - train_batch_size=train_batch_size, - train_data_dir=train_data_dir, - use_wandb=use_wandb, - v2=v2, - v_parameterization=v_parameterization, - v_pred_like_loss=v_pred_like_loss, - vae=vae, - vae_batch_size=vae_batch_size, - wandb_api_key=wandb_api_key, - weighted_captions=weighted_captions, - xformers=xformers, - ) - - run_cmd += run_cmd_sample( - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - output_dir, - ) - - if print_only_bool: - log.warning( - "Here is the trainer command as a reference. It will not be executed:\n" - ) - print(run_cmd) - - save_to_file(run_cmd) - else: - # Saving config file for model - current_datetime = datetime.now() - formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") - file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json") - - log.info(f"Saving training config to {file_path}...") - - SaveConfigFile( - parameters=parameters, - file_path=file_path, - exclusion=["file_path", "save_as", "headless", "print_only"], - ) - - log.info(run_cmd) - - # Run the command - - executor.execute_command(run_cmd=run_cmd) - - # check if output_dir/last is a folder... therefore it is a diffuser model - last_dir = pathlib.Path(f"{output_dir}/{output_name}") - - if not last_dir.is_dir(): - # Copy inference model for v2 if required - save_inference_file(output_dir, v2, v_parameterization, output_name) - - -def dreambooth_tab( - # train_data_dir=gr.Textbox(), - # reg_data_dir=gr.Textbox(), - # output_dir=gr.Textbox(), - # logging_dir=gr.Textbox(), - headless=False, -): - dummy_db_true = gr.Label(value=True, visible=False) - dummy_db_false = gr.Label(value=False, visible=False) - dummy_headless = gr.Label(value=headless, visible=False) - - with gr.Tab("Training"): - gr.Markdown("Train a custom model using kohya dreambooth python code...") - - # Setup Configuration Files Gradio - config = ConfigurationFile(headless) - - source_model = SourceModel(headless=headless) - - with gr.Tab("Folders"): - folders = Folders(headless=headless) - with gr.Tab("Parameters"): - with gr.Tab("Basic", elem_id="basic_tab"): - basic_training = BasicTraining( - learning_rate_value="1e-5", - lr_scheduler_value="cosine", - lr_warmup_value="10", - dreambooth=True, - sdxl_checkbox=source_model.sdxl_checkbox, - ) - - # # Add SDXL Parameters - # sdxl_params = SDXLParameters(source_model.sdxl_checkbox, show_sdxl_cache_text_encoder_outputs=False) - - with gr.Tab("Advanced", elem_id="advanced_tab"): - advanced_training = AdvancedTraining(headless=headless) - advanced_training.color_aug.change( - color_aug_changed, - inputs=[advanced_training.color_aug], - outputs=[basic_training.cache_latents], - ) - - with gr.Tab("Samples", elem_id="samples_tab"): - sample = SampleImages() - - with gr.Tab("Dataset Preparation"): - gr.Markdown( - "This section provide Dreambooth tools to help setup your dataset..." 
- ) - gradio_dreambooth_folder_creation_tab( - train_data_dir_input=folders.train_data_dir, - reg_data_dir_input=folders.reg_data_dir, - output_dir_input=folders.output_dir, - logging_dir_input=folders.logging_dir, - headless=headless, - ) - gradio_dataset_balancing_tab(headless=headless) - - with gr.Row(): - button_run = gr.Button("Start training", variant="primary") - - button_stop_training = gr.Button("Stop training") - - button_print = gr.Button("Print training command") - - # Setup gradio tensorboard buttons - ( - button_start_tensorboard, - button_stop_tensorboard, - ) = gradio_tensorboard() - - button_start_tensorboard.click( - start_tensorboard, - inputs=[dummy_headless, folders.logging_dir], - show_progress=False, - ) - - button_stop_tensorboard.click( - stop_tensorboard, - show_progress=False, - ) - - settings_list = [ - source_model.pretrained_model_name_or_path, - source_model.v2, - source_model.v_parameterization, - source_model.sdxl_checkbox, - folders.logging_dir, - folders.train_data_dir, - folders.reg_data_dir, - folders.output_dir, - basic_training.max_resolution, - basic_training.learning_rate, - basic_training.learning_rate_te, - basic_training.learning_rate_te1, - basic_training.learning_rate_te2, - basic_training.lr_scheduler, - basic_training.lr_warmup, - basic_training.train_batch_size, - basic_training.epoch, - basic_training.save_every_n_epochs, - basic_training.mixed_precision, - basic_training.save_precision, - basic_training.seed, - basic_training.num_cpu_threads_per_process, - basic_training.cache_latents, - basic_training.cache_latents_to_disk, - basic_training.caption_extension, - basic_training.enable_bucket, - advanced_training.gradient_checkpointing, - advanced_training.full_fp16, - advanced_training.full_bf16, - advanced_training.no_token_padding, - basic_training.stop_text_encoder_training, - basic_training.min_bucket_reso, - basic_training.max_bucket_reso, - advanced_training.xformers, - source_model.save_model_as, - advanced_training.shuffle_caption, - advanced_training.save_state, - advanced_training.resume, - advanced_training.prior_loss_weight, - advanced_training.color_aug, - advanced_training.flip_aug, - advanced_training.clip_skip, - advanced_training.vae, - advanced_training.num_processes, - advanced_training.num_machines, - advanced_training.multi_gpu, - advanced_training.gpu_ids, - folders.output_name, - advanced_training.max_token_length, - basic_training.max_train_epochs, - basic_training.max_train_steps, - advanced_training.max_data_loader_n_workers, - advanced_training.mem_eff_attn, - advanced_training.gradient_accumulation_steps, - source_model.model_list, - advanced_training.keep_tokens, - basic_training.lr_scheduler_num_cycles, - basic_training.lr_scheduler_power, - advanced_training.persistent_data_loader_workers, - advanced_training.bucket_no_upscale, - advanced_training.random_crop, - advanced_training.bucket_reso_steps, - advanced_training.v_pred_like_loss, - advanced_training.caption_dropout_every_n_epochs, - advanced_training.caption_dropout_rate, - basic_training.optimizer, - basic_training.optimizer_args, - basic_training.lr_scheduler_args, - advanced_training.noise_offset_type, - advanced_training.noise_offset, - advanced_training.adaptive_noise_scale, - advanced_training.multires_noise_iterations, - advanced_training.multires_noise_discount, - sample.sample_every_n_steps, - sample.sample_every_n_epochs, - sample.sample_sampler, - sample.sample_prompts, - advanced_training.additional_parameters, - 
advanced_training.vae_batch_size, - advanced_training.min_snr_gamma, - advanced_training.weighted_captions, - advanced_training.save_every_n_steps, - advanced_training.save_last_n_steps, - advanced_training.save_last_n_steps_state, - advanced_training.use_wandb, - advanced_training.wandb_api_key, - advanced_training.scale_v_pred_loss_like_noise_pred, - advanced_training.min_timestep, - advanced_training.max_timestep, - ] - - config.button_open_config.click( - open_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name] + settings_list, - show_progress=False, - ) - - config.button_load_config.click( - open_configuration, - inputs=[dummy_db_false, config.config_file_name] + settings_list, - outputs=[config.config_file_name] + settings_list, - show_progress=False, - ) - - config.button_save_config.click( - save_configuration, - inputs=[dummy_db_false, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) - - config.button_save_as_config.click( - save_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) - - button_run.click( - train_model, - inputs=[dummy_headless] + [dummy_db_false] + settings_list, - show_progress=False, - ) - - button_stop_training.click(executor.kill_command) - - button_print.click( - train_model, - inputs=[dummy_headless] + [dummy_db_true] + settings_list, - show_progress=False, - ) - - return ( - folders.train_data_dir, - folders.reg_data_dir, - folders.output_dir, - folders.logging_dir, - ) +# Set up logging +log = setup_logging() def UI(**kwargs): diff --git a/finetune_gui.py b/finetune_gui.py index 73e4b2353..73eb81a25 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -1,1091 +1,16 @@ +import argparse import gradio as gr -import json -import math import os -import subprocess -import pathlib -import argparse -from datetime import datetime -from library.common_gui import ( - get_folder_path, - get_file_path, - get_saveasfile_path, - save_inference_file, - run_cmd_advanced_training, - color_aug_changed, - update_my_data, - check_if_model_exist, - SaveConfigFile, - save_to_file, -) -from library.class_configuration_file import ConfigurationFile -from library.class_source_model import SourceModel -from library.class_basic_training import BasicTraining -from library.class_advanced_training import AdvancedTraining -from library.class_sdxl_parameters import SDXLParameters -from library.class_command_executor import CommandExecutor -from library.tensorboard_gui import ( - gradio_tensorboard, - start_tensorboard, - stop_tensorboard, -) -from library.utilities import utilities_tab -from library.class_sample_images import SampleImages, run_cmd_sample -from library.custom_logging import setup_logging -from library.localization_ext import add_javascript +from kohya_gui.utilities import utilities_tab +from kohya_gui.finetune_gui import finetune_tab + +from kohya_gui.custom_logging import setup_logging +from kohya_gui.localization_ext import add_javascript # Set up logging log = setup_logging() -# Setup command executor -executor = CommandExecutor() - -# from easygui import msgbox - -folder_symbol = "\U0001f4c2" # 📂 -refresh_symbol = "\U0001f504" # 🔄 -save_style_symbol = "\U0001f4be" # 💾 -document_symbol = "\U0001F4C4" # 📄 - -PYTHON = "python3" if os.name == "posix" else "./venv/Scripts/python.exe" - - -def save_configuration( - save_as, - file_path, - 
pretrained_model_name_or_path, - v2, - v_parameterization, - sdxl_checkbox, - train_dir, - image_folder, - output_dir, - logging_dir, - max_resolution, - min_bucket_reso, - max_bucket_reso, - batch_size, - flip_aug, - caption_metadata_filename, - latent_metadata_filename, - full_path, - learning_rate, - lr_scheduler, - lr_warmup, - dataset_repeats, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - learning_rate_te, - learning_rate_te1, - learning_rate_te2, - train_text_encoder, - full_bf16, - create_caption, - create_buckets, - save_model_as, - caption_extension, - # use_8bit_adam, - xformers, - clip_skip, - num_processes, - num_machines, - multi_gpu, - gpu_ids, - save_state, - resume, - gradient_checkpointing, - gradient_accumulation_steps, - block_lr, - mem_eff_attn, - shuffle_caption, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - full_fp16, - color_aug, - model_list, - cache_latents, - cache_latents_to_disk, - use_latent_files, - keep_tokens, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - v_pred_like_loss, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - lr_scheduler_args, - noise_offset_type, - noise_offset, - adaptive_noise_scale, - multires_noise_iterations, - multires_noise_discount, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, - weighted_captions, - save_every_n_steps, - save_last_n_steps, - save_last_n_steps_state, - use_wandb, - wandb_api_key, - scale_v_pred_loss_like_noise_pred, - sdxl_cache_text_encoder_outputs, - sdxl_no_half_vae, - min_timestep, - max_timestep, -): - # Get list of function parameters and values - parameters = list(locals().items()) - - original_file_path = file_path - - save_as_bool = True if save_as.get("label") == "True" else False - - if save_as_bool: - log.info("Save as...") - file_path = get_saveasfile_path(file_path) - else: - log.info("Save...") - if file_path == None or file_path == "": - file_path = get_saveasfile_path(file_path) - - # log.info(file_path) - - if file_path == None or file_path == "": - return original_file_path # In case a file_path was provided and the user decide to cancel the open action - - # Extract the destination directory from the file path - destination_directory = os.path.dirname(file_path) - - # Create the destination directory if it doesn't exist - if not os.path.exists(destination_directory): - os.makedirs(destination_directory) - - SaveConfigFile( - parameters=parameters, - file_path=file_path, - exclusion=["file_path", "save_as"], - ) - - return file_path - - -def open_configuration( - ask_for_file, - apply_preset, - file_path, - pretrained_model_name_or_path, - v2, - v_parameterization, - sdxl_checkbox, - train_dir, - image_folder, - output_dir, - logging_dir, - max_resolution, - min_bucket_reso, - max_bucket_reso, - batch_size, - flip_aug, - caption_metadata_filename, - latent_metadata_filename, - full_path, - learning_rate, - lr_scheduler, - lr_warmup, - dataset_repeats, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - learning_rate_te, - learning_rate_te1, - learning_rate_te2, - train_text_encoder, - full_bf16, - create_caption, - create_buckets, - save_model_as, - caption_extension, - # use_8bit_adam, - xformers, - clip_skip, - 
num_processes, - num_machines, - multi_gpu, - gpu_ids, - save_state, - resume, - gradient_checkpointing, - gradient_accumulation_steps, - block_lr, - mem_eff_attn, - shuffle_caption, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - full_fp16, - color_aug, - model_list, - cache_latents, - cache_latents_to_disk, - use_latent_files, - keep_tokens, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - v_pred_like_loss, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - lr_scheduler_args, - noise_offset_type, - noise_offset, - adaptive_noise_scale, - multires_noise_iterations, - multires_noise_discount, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, - weighted_captions, - save_every_n_steps, - save_last_n_steps, - save_last_n_steps_state, - use_wandb, - wandb_api_key, - scale_v_pred_loss_like_noise_pred, - sdxl_cache_text_encoder_outputs, - sdxl_no_half_vae, - min_timestep, - max_timestep, - training_preset, -): - # Get list of function parameters and values - parameters = list(locals().items()) - - ask_for_file = True if ask_for_file.get("label") == "True" else False - apply_preset = True if apply_preset.get("label") == "True" else False - - # Check if we are "applying" a preset or a config - if apply_preset: - log.info(f"Applying preset {training_preset}...") - file_path = f"./presets/finetune/{training_preset}.json" - else: - # If not applying a preset, set the `training_preset` field to an empty string - # Find the index of the `training_preset` parameter using the `index()` method - training_preset_index = parameters.index(("training_preset", training_preset)) - - # Update the value of `training_preset` by directly assigning an empty string value - parameters[training_preset_index] = ("training_preset", "") - - original_file_path = file_path - - if ask_for_file: - file_path = get_file_path(file_path) - - if not file_path == "" and not file_path == None: - # load variables from JSON file - with open(file_path, "r") as f: - my_data = json.load(f) - log.info("Loading config...") - # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True - my_data = update_my_data(my_data) - else: - file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action - my_data = {} - - values = [file_path] - for key, value in parameters: - json_value = my_data.get(key) - # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found - if not key in ["ask_for_file", "apply_preset", "file_path"]: - values.append(json_value if json_value is not None else value) - return tuple(values) - - -def train_model( - headless, - print_only, - pretrained_model_name_or_path, - v2, - v_parameterization, - sdxl_checkbox, - train_dir, - image_folder, - output_dir, - logging_dir, - max_resolution, - min_bucket_reso, - max_bucket_reso, - batch_size, - flip_aug, - caption_metadata_filename, - latent_metadata_filename, - full_path, - learning_rate, - lr_scheduler, - lr_warmup, - dataset_repeats, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - learning_rate_te, - learning_rate_te1, - learning_rate_te2, - train_text_encoder, - full_bf16, - generate_caption_database, - generate_image_buckets, - 
save_model_as, - caption_extension, - # use_8bit_adam, - xformers, - clip_skip, - num_processes, - num_machines, - multi_gpu, - gpu_ids, - save_state, - resume, - gradient_checkpointing, - gradient_accumulation_steps, - block_lr, - mem_eff_attn, - shuffle_caption, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - full_fp16, - color_aug, - model_list, # Keep this. Yes, it is unused here but required given the common list used - cache_latents, - cache_latents_to_disk, - use_latent_files, - keep_tokens, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - v_pred_like_loss, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - lr_scheduler_args, - noise_offset_type, - noise_offset, - adaptive_noise_scale, - multires_noise_iterations, - multires_noise_discount, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, - weighted_captions, - save_every_n_steps, - save_last_n_steps, - save_last_n_steps_state, - use_wandb, - wandb_api_key, - scale_v_pred_loss_like_noise_pred, - sdxl_cache_text_encoder_outputs, - sdxl_no_half_vae, - min_timestep, - max_timestep, -): - # Get list of function parameters and values - parameters = list(locals().items()) - - print_only_bool = True if print_only.get("label") == "True" else False - log.info(f"Start Finetuning...") - - headless_bool = True if headless.get("label") == "True" else False - - if check_if_model_exist(output_name, output_dir, save_model_as, headless_bool): - return - - # if float(noise_offset) > 0 and ( - # multires_noise_iterations > 0 or multires_noise_discount > 0 - # ): - # output_message( - # msg="noise offset and multires_noise can't be set at the same time. 
Only use one or the other.", - # title='Error', - # headless=headless_bool, - # ) - # return - - # if optimizer == 'Adafactor' and lr_warmup != '0': - # output_message( - # msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", - # title='Warning', - # headless=headless_bool, - # ) - # lr_warmup = '0' - - # create caption json file - if generate_caption_database: - if not os.path.exists(train_dir): - os.mkdir(train_dir) - - run_cmd = f"{PYTHON} finetune/merge_captions_to_metadata.py" - if caption_extension == "": - run_cmd += f' --caption_extension=".caption"' - else: - run_cmd += f" --caption_extension={caption_extension}" - run_cmd += f' "{image_folder}"' - run_cmd += f' "{train_dir}/{caption_metadata_filename}"' - if full_path: - run_cmd += f" --full_path" - - log.info(run_cmd) - - if not print_only_bool: - # Run the command - if os.name == "posix": - os.system(run_cmd) - else: - subprocess.run(run_cmd) - - # create images buckets - if generate_image_buckets: - run_cmd = f"{PYTHON} finetune/prepare_buckets_latents.py" - run_cmd += f' "{image_folder}"' - run_cmd += f' "{train_dir}/{caption_metadata_filename}"' - run_cmd += f' "{train_dir}/{latent_metadata_filename}"' - run_cmd += f' "{pretrained_model_name_or_path}"' - run_cmd += f" --batch_size={batch_size}" - run_cmd += f" --max_resolution={max_resolution}" - run_cmd += f" --min_bucket_reso={min_bucket_reso}" - run_cmd += f" --max_bucket_reso={max_bucket_reso}" - run_cmd += f" --mixed_precision={mixed_precision}" - # if flip_aug: - # run_cmd += f' --flip_aug' - if full_path: - run_cmd += f" --full_path" - if sdxl_checkbox and sdxl_no_half_vae: - log.info("Using mixed_precision = no because no half vae is selected...") - run_cmd += f' --mixed_precision="no"' - - log.info(run_cmd) - - if not print_only_bool: - # Run the command - if os.name == "posix": - os.system(run_cmd) - else: - subprocess.run(run_cmd) - - image_num = len( - [ - f - for f, lower_f in ( - (file, file.lower()) for file in os.listdir(image_folder) - ) - if lower_f.endswith((".jpg", ".jpeg", ".png", ".webp")) - ] - ) - log.info(f"image_num = {image_num}") - - repeats = int(image_num) * int(dataset_repeats) - log.info(f"repeats = {str(repeats)}") - - # calculate max_train_steps - max_train_steps = int( - math.ceil( - float(repeats) - / int(train_batch_size) - / int(gradient_accumulation_steps) - * int(epoch) - ) - ) - - # Divide by two because flip augmentation create two copied of the source images - if flip_aug: - max_train_steps = int(math.ceil(float(max_train_steps) / 2)) - - log.info(f"max_train_steps = {max_train_steps}") - - lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) - log.info(f"lr_warmup_steps = {lr_warmup_steps}") - - run_cmd = "accelerate launch" - - run_cmd += run_cmd_advanced_training( - num_processes=num_processes, - num_machines=num_machines, - multi_gpu=multi_gpu, - gpu_ids=gpu_ids, - num_cpu_threads_per_process=num_cpu_threads_per_process, - ) - - if sdxl_checkbox: - run_cmd += f' "./sdxl_train.py"' - else: - run_cmd += f' "./fine_tune.py"' - - in_json = ( - f"{train_dir}/{latent_metadata_filename}" - if use_latent_files == "Yes" - else f"{train_dir}/{caption_metadata_filename}" - ) - cache_text_encoder_outputs = sdxl_checkbox and sdxl_cache_text_encoder_outputs - no_half_vae = sdxl_checkbox and sdxl_no_half_vae - - run_cmd += run_cmd_advanced_training( - adaptive_noise_scale=adaptive_noise_scale, - additional_parameters=additional_parameters, - block_lr=block_lr, - 
bucket_no_upscale=bucket_no_upscale, - bucket_reso_steps=bucket_reso_steps, - cache_latents=cache_latents, - cache_latents_to_disk=cache_latents_to_disk, - cache_text_encoder_outputs=cache_text_encoder_outputs - if sdxl_checkbox - else None, - caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, - caption_dropout_rate=caption_dropout_rate, - caption_extension=caption_extension, - clip_skip=clip_skip, - color_aug=color_aug, - dataset_repeats=dataset_repeats, - enable_bucket=True, - flip_aug=flip_aug, - full_bf16=full_bf16, - full_fp16=full_fp16, - gradient_accumulation_steps=gradient_accumulation_steps, - gradient_checkpointing=gradient_checkpointing, - in_json=in_json, - keep_tokens=keep_tokens, - learning_rate=learning_rate, - learning_rate_te1=learning_rate_te1 if sdxl_checkbox else None, - learning_rate_te2=learning_rate_te2 if sdxl_checkbox else None, - learning_rate_te=learning_rate_te if not sdxl_checkbox else None, - logging_dir=logging_dir, - lr_scheduler=lr_scheduler, - lr_scheduler_args=lr_scheduler_args, - lr_warmup_steps=lr_warmup_steps, - max_bucket_reso=max_bucket_reso, - max_data_loader_n_workers=max_data_loader_n_workers, - max_resolution=max_resolution, - max_timestep=max_timestep, - max_token_length=max_token_length, - max_train_epochs=max_train_epochs, - max_train_steps=max_train_steps, - mem_eff_attn=mem_eff_attn, - min_bucket_reso=min_bucket_reso, - min_snr_gamma=min_snr_gamma, - min_timestep=min_timestep, - mixed_precision=mixed_precision, - multires_noise_discount=multires_noise_discount, - multires_noise_iterations=multires_noise_iterations, - no_half_vae=no_half_vae if sdxl_checkbox else None, - noise_offset=noise_offset, - noise_offset_type=noise_offset_type, - optimizer=optimizer, - optimizer_args=optimizer_args, - output_dir=output_dir, - output_name=output_name, - persistent_data_loader_workers=persistent_data_loader_workers, - pretrained_model_name_or_path=pretrained_model_name_or_path, - random_crop=random_crop, - resume=resume, - save_every_n_epochs=save_every_n_epochs, - save_every_n_steps=save_every_n_steps, - save_last_n_steps=save_last_n_steps, - save_last_n_steps_state=save_last_n_steps_state, - save_model_as=save_model_as, - save_precision=save_precision, - save_state=save_state, - scale_v_pred_loss_like_noise_pred=scale_v_pred_loss_like_noise_pred, - seed=seed, - shuffle_caption=shuffle_caption, - train_batch_size=train_batch_size, - train_data_dir=image_folder, - train_text_encoder=train_text_encoder, - use_wandb=use_wandb, - v2=v2, - v_parameterization=v_parameterization, - v_pred_like_loss=v_pred_like_loss, - vae_batch_size=vae_batch_size, - wandb_api_key=wandb_api_key, - weighted_captions=weighted_captions, - xformers=xformers, - ) - - run_cmd += run_cmd_sample( - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - output_dir, - ) - - if print_only_bool: - log.warning( - "Here is the trainer command as a reference. 
It will not be executed:\n" - ) - print(run_cmd) - - save_to_file(run_cmd) - else: - # Saving config file for model - current_datetime = datetime.now() - formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") - file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json") - - log.info(f"Saving training config to {file_path}...") - - SaveConfigFile( - parameters=parameters, - file_path=file_path, - exclusion=["file_path", "save_as", "headless", "print_only"], - ) - - log.info(run_cmd) - - # Run the command - executor.execute_command(run_cmd=run_cmd) - - # check if output_dir/last is a folder... therefore it is a diffuser model - last_dir = pathlib.Path(f"{output_dir}/{output_name}") - - if not last_dir.is_dir(): - # Copy inference model for v2 if required - save_inference_file(output_dir, v2, v_parameterization, output_name) - - -def remove_doublequote(file_path): - if file_path != None: - file_path = file_path.replace('"', "") - - return file_path - - -def finetune_tab(headless=False): - dummy_db_true = gr.Label(value=True, visible=False) - dummy_db_false = gr.Label(value=False, visible=False) - dummy_headless = gr.Label(value=headless, visible=False) - with gr.Tab("Training"): - gr.Markdown("Train a custom model using kohya finetune python code...") - - # Setup Configuration Files Gradio - config = ConfigurationFile(headless) - - source_model = SourceModel(headless=headless) - - with gr.Tab("Folders"): - with gr.Row(): - train_dir = gr.Textbox( - label="Training config folder", - placeholder="folder where the training configuration files will be saved", - ) - train_dir_folder = gr.Button( - folder_symbol, - elem_id="open_folder_small", - visible=(not headless), - ) - train_dir_folder.click( - get_folder_path, - outputs=train_dir, - show_progress=False, - ) - - image_folder = gr.Textbox( - label="Training Image folder", - placeholder="folder where the training images are located", - ) - image_folder_input_folder = gr.Button( - folder_symbol, - elem_id="open_folder_small", - visible=(not headless), - ) - image_folder_input_folder.click( - get_folder_path, - outputs=image_folder, - show_progress=False, - ) - with gr.Row(): - output_dir = gr.Textbox( - label="Model output folder", - placeholder="folder where the model will be saved", - ) - output_dir_input_folder = gr.Button( - folder_symbol, - elem_id="open_folder_small", - visible=(not headless), - ) - output_dir_input_folder.click( - get_folder_path, - outputs=output_dir, - show_progress=False, - ) - - logging_dir = gr.Textbox( - label="Logging folder", - placeholder="Optional: enable logging and output TensorBoard log to this folder", - ) - logging_dir_input_folder = gr.Button( - folder_symbol, - elem_id="open_folder_small", - visible=(not headless), - ) - logging_dir_input_folder.click( - get_folder_path, - outputs=logging_dir, - show_progress=False, - ) - with gr.Row(): - output_name = gr.Textbox( - label="Model output name", - placeholder="Name of the model to output", - value="last", - interactive=True, - ) - train_dir.change( - remove_doublequote, - inputs=[train_dir], - outputs=[train_dir], - ) - image_folder.change( - remove_doublequote, - inputs=[image_folder], - outputs=[image_folder], - ) - output_dir.change( - remove_doublequote, - inputs=[output_dir], - outputs=[output_dir], - ) - with gr.Tab("Dataset preparation"): - with gr.Row(): - max_resolution = gr.Textbox( - label="Resolution (width,height)", value="512,512" - ) - min_bucket_reso = gr.Textbox(label="Min bucket resolution", value="256") - 
max_bucket_reso = gr.Textbox( - label="Max bucket resolution", value="1024" - ) - batch_size = gr.Textbox(label="Batch size", value="1") - with gr.Row(): - create_caption = gr.Checkbox( - label="Generate caption metadata", value=True - ) - create_buckets = gr.Checkbox( - label="Generate image buckets metadata", value=True - ) - use_latent_files = gr.Dropdown( - label="Use latent files", - choices=[ - "No", - "Yes", - ], - value="Yes", - ) - with gr.Accordion("Advanced parameters", open=False): - with gr.Row(): - caption_metadata_filename = gr.Textbox( - label="Caption metadata filename", - value="meta_cap.json", - ) - latent_metadata_filename = gr.Textbox( - label="Latent metadata filename", value="meta_lat.json" - ) - with gr.Row(): - full_path = gr.Checkbox(label="Use full path", value=True) - weighted_captions = gr.Checkbox( - label="Weighted captions", value=False - ) - with gr.Tab("Parameters"): - - def list_presets(path): - json_files = [] - - for file in os.listdir(path): - if file.endswith(".json"): - json_files.append(os.path.splitext(file)[0]) - - user_presets_path = os.path.join(path, "user_presets") - if os.path.isdir(user_presets_path): - for file in os.listdir(user_presets_path): - if file.endswith(".json"): - preset_name = os.path.splitext(file)[0] - json_files.append(os.path.join("user_presets", preset_name)) - - return json_files - - training_preset = gr.Dropdown( - label="Presets", - choices=list_presets("./presets/finetune"), - elem_id="myDropdown", - ) - - with gr.Tab("Basic", elem_id="basic_tab"): - basic_training = BasicTraining( - learning_rate_value="1e-5", - finetuning=True, - sdxl_checkbox=source_model.sdxl_checkbox, - ) - - # Add SDXL Parameters - sdxl_params = SDXLParameters(source_model.sdxl_checkbox) - - with gr.Row(): - dataset_repeats = gr.Textbox(label="Dataset repeats", value=40) - train_text_encoder = gr.Checkbox( - label="Train text encoder", value=True - ) - - with gr.Tab("Advanced", elem_id="advanced_tab"): - with gr.Row(): - gradient_accumulation_steps = gr.Number( - label="Gradient accumulate steps", value="1" - ) - block_lr = gr.Textbox( - label="Block LR", - placeholder="(Optional)", - info="Specify the different learning rates for each U-Net block. Specify 23 values separated by commas like 1e-3,1e-3 ... 
1e-3", - ) - advanced_training = AdvancedTraining(headless=headless, finetuning=True) - advanced_training.color_aug.change( - color_aug_changed, - inputs=[advanced_training.color_aug], - outputs=[ - basic_training.cache_latents - ], # Not applicable to fine_tune.py - ) - - with gr.Tab("Samples", elem_id="samples_tab"): - sample = SampleImages() - - with gr.Row(): - button_run = gr.Button("Start training", variant="primary") - - button_stop_training = gr.Button("Stop training") - - button_print = gr.Button("Print training command") - - # Setup gradio tensorboard buttons - ( - button_start_tensorboard, - button_stop_tensorboard, - ) = gradio_tensorboard() - - button_start_tensorboard.click( - start_tensorboard, - inputs=[dummy_headless, logging_dir], - ) - - button_stop_tensorboard.click( - stop_tensorboard, - show_progress=False, - ) - - settings_list = [ - source_model.pretrained_model_name_or_path, - source_model.v2, - source_model.v_parameterization, - source_model.sdxl_checkbox, - train_dir, - image_folder, - output_dir, - logging_dir, - max_resolution, - min_bucket_reso, - max_bucket_reso, - batch_size, - advanced_training.flip_aug, - caption_metadata_filename, - latent_metadata_filename, - full_path, - basic_training.learning_rate, - basic_training.lr_scheduler, - basic_training.lr_warmup, - dataset_repeats, - basic_training.train_batch_size, - basic_training.epoch, - basic_training.save_every_n_epochs, - basic_training.mixed_precision, - basic_training.save_precision, - basic_training.seed, - basic_training.num_cpu_threads_per_process, - basic_training.learning_rate_te, - basic_training.learning_rate_te1, - basic_training.learning_rate_te2, - train_text_encoder, - advanced_training.full_bf16, - create_caption, - create_buckets, - source_model.save_model_as, - basic_training.caption_extension, - advanced_training.xformers, - advanced_training.clip_skip, - advanced_training.num_processes, - advanced_training.num_machines, - advanced_training.multi_gpu, - advanced_training.gpu_ids, - advanced_training.save_state, - advanced_training.resume, - advanced_training.gradient_checkpointing, - gradient_accumulation_steps, - block_lr, - advanced_training.mem_eff_attn, - advanced_training.shuffle_caption, - output_name, - advanced_training.max_token_length, - basic_training.max_train_epochs, - advanced_training.max_data_loader_n_workers, - advanced_training.full_fp16, - advanced_training.color_aug, - source_model.model_list, - basic_training.cache_latents, - basic_training.cache_latents_to_disk, - use_latent_files, - advanced_training.keep_tokens, - advanced_training.persistent_data_loader_workers, - advanced_training.bucket_no_upscale, - advanced_training.random_crop, - advanced_training.bucket_reso_steps, - advanced_training.v_pred_like_loss, - advanced_training.caption_dropout_every_n_epochs, - advanced_training.caption_dropout_rate, - basic_training.optimizer, - basic_training.optimizer_args, - basic_training.lr_scheduler_args, - advanced_training.noise_offset_type, - advanced_training.noise_offset, - advanced_training.adaptive_noise_scale, - advanced_training.multires_noise_iterations, - advanced_training.multires_noise_discount, - sample.sample_every_n_steps, - sample.sample_every_n_epochs, - sample.sample_sampler, - sample.sample_prompts, - advanced_training.additional_parameters, - advanced_training.vae_batch_size, - advanced_training.min_snr_gamma, - weighted_captions, - advanced_training.save_every_n_steps, - advanced_training.save_last_n_steps, - 
advanced_training.save_last_n_steps_state, - advanced_training.use_wandb, - advanced_training.wandb_api_key, - advanced_training.scale_v_pred_loss_like_noise_pred, - sdxl_params.sdxl_cache_text_encoder_outputs, - sdxl_params.sdxl_no_half_vae, - advanced_training.min_timestep, - advanced_training.max_timestep, - ] - - config.button_open_config.click( - open_configuration, - inputs=[dummy_db_true, dummy_db_false, config.config_file_name] - + settings_list - + [training_preset], - outputs=[config.config_file_name] + settings_list + [training_preset], - show_progress=False, - ) - - # config.button_open_config.click( - # open_configuration, - # inputs=[dummy_db_true, dummy_db_false, config.config_file_name] + settings_list, - # outputs=[config.config_file_name] + settings_list, - # show_progress=False, - # ) - - config.button_load_config.click( - open_configuration, - inputs=[dummy_db_false, dummy_db_false, config.config_file_name] - + settings_list - + [training_preset], - outputs=[config.config_file_name] + settings_list + [training_preset], - show_progress=False, - ) - - # config.button_load_config.click( - # open_configuration, - # inputs=[dummy_db_false, config.config_file_name] + settings_list, - # outputs=[config.config_file_name] + settings_list, - # show_progress=False, - # ) - - training_preset.input( - open_configuration, - inputs=[dummy_db_false, dummy_db_true, config.config_file_name] - + settings_list - + [training_preset], - outputs=[gr.Textbox()] + settings_list + [training_preset], - show_progress=False, - ) - - button_run.click( - train_model, - inputs=[dummy_headless] + [dummy_db_false] + settings_list, - show_progress=False, - ) - - button_stop_training.click(executor.kill_command) - - button_print.click( - train_model, - inputs=[dummy_headless] + [dummy_db_true] + settings_list, - show_progress=False, - ) - - config.button_save_config.click( - save_configuration, - inputs=[dummy_db_false, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) - - config.button_save_as_config.click( - save_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) - - with gr.Tab("Guides"): - gr.Markdown("This section provide Various Finetuning guides and information...") - top_level_path = "./docs/Finetuning/top_level.md" - if os.path.exists(top_level_path): - with open(os.path.join(top_level_path), "r", encoding="utf8") as file: - guides_top_level = file.read() + "\n" - gr.Markdown(guides_top_level) - def UI(**kwargs): add_javascript(kwargs.get("language")) diff --git a/kohya_gui.py b/kohya_gui.py index 8e58ec8ac..78be71d73 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -4,13 +4,12 @@ from dreambooth_gui import dreambooth_tab from finetune_gui import finetune_tab from textual_inversion_gui import ti_tab -from library.utilities import utilities_tab +from kohya_gui.utilities import utilities_tab from lora_gui import lora_tab -from library.class_lora_tab import LoRATools +from kohya_gui.class_lora_tab import LoRATools -import os -from library.custom_logging import setup_logging -from library.localization_ext import add_javascript +from kohya_gui.custom_logging import setup_logging +from kohya_gui.localization_ext import add_javascript # Set up logging log = setup_logging() diff --git a/kohya_gui/__init__.py b/kohya_gui/__init__.py new file mode 100644 index 000000000..bc6ba3257 --- /dev/null +++ b/kohya_gui/__init__.py @@ -0,0 +1 @@ +"""empty""" diff --git 
a/library/basic_caption_gui.py b/kohya_gui/basic_caption_gui.py similarity index 99% rename from library/basic_caption_gui.py rename to kohya_gui/basic_caption_gui.py index 29b0bf1f6..a8e6d52c9 100644 --- a/library/basic_caption_gui.py +++ b/kohya_gui/basic_caption_gui.py @@ -4,7 +4,7 @@ from .common_gui import get_folder_path, add_pre_postfix, find_replace import os -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/library/blip_caption_gui.py b/kohya_gui/blip_caption_gui.py similarity index 98% rename from library/blip_caption_gui.py rename to kohya_gui/blip_caption_gui.py index 3679140c8..358f3cf05 100644 --- a/library/blip_caption_gui.py +++ b/kohya_gui/blip_caption_gui.py @@ -3,7 +3,7 @@ import subprocess import os from .common_gui import get_folder_path, add_pre_postfix -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/library/class_advanced_training.py b/kohya_gui/class_advanced_training.py similarity index 100% rename from library/class_advanced_training.py rename to kohya_gui/class_advanced_training.py diff --git a/library/class_basic_training.py b/kohya_gui/class_basic_training.py similarity index 100% rename from library/class_basic_training.py rename to kohya_gui/class_basic_training.py diff --git a/library/class_command_executor.py b/kohya_gui/class_command_executor.py similarity index 95% rename from library/class_command_executor.py rename to kohya_gui/class_command_executor.py index 9176533cb..770c8c407 100644 --- a/library/class_command_executor.py +++ b/kohya_gui/class_command_executor.py @@ -1,6 +1,6 @@ import subprocess import psutil -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/library/class_configuration_file.py b/kohya_gui/class_configuration_file.py similarity index 100% rename from library/class_configuration_file.py rename to kohya_gui/class_configuration_file.py diff --git a/library/class_dreambooth_gui.py b/kohya_gui/class_dreambooth_gui.py similarity index 97% rename from library/class_dreambooth_gui.py rename to kohya_gui/class_dreambooth_gui.py index 81288b8c7..b74c56ee9 100644 --- a/library/class_dreambooth_gui.py +++ b/kohya_gui/class_dreambooth_gui.py @@ -6,10 +6,10 @@ from .class_basic_training import BasicTraining from .class_advanced_training import AdvancedTraining from .class_sample_images import SampleImages -from library.dreambooth_folder_creation_gui import ( +from .dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, ) -from library.dataset_balancing_gui import gradio_dataset_balancing_tab +from .dataset_balancing_gui import gradio_dataset_balancing_tab from .common_gui import color_aug_changed diff --git a/library/class_folders.py b/kohya_gui/class_folders.py similarity index 100% rename from library/class_folders.py rename to kohya_gui/class_folders.py diff --git a/library/class_lora_tab.py b/kohya_gui/class_lora_tab.py similarity index 64% rename from library/class_lora_tab.py rename to kohya_gui/class_lora_tab.py index df6a4f387..2f33d9860 100644 --- a/library/class_lora_tab.py +++ b/kohya_gui/class_lora_tab.py @@ -1,17 +1,17 @@ import gradio as gr -from library.merge_lora_gui import GradioMergeLoRaTab -from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab -from library.verify_lora_gui import gradio_verify_lora_tab -from 
library.resize_lora_gui import gradio_resize_lora_tab -from library.extract_lora_gui import gradio_extract_lora_tab -from library.convert_lcm_gui import gradio_convert_lcm_tab -from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab -from library.extract_lora_from_dylora_gui import gradio_extract_dylora_tab -from library.merge_lycoris_gui import gradio_merge_lycoris_tab +from .merge_lora_gui import GradioMergeLoRaTab +from .svd_merge_lora_gui import gradio_svd_merge_lora_tab +from .verify_lora_gui import gradio_verify_lora_tab +from .resize_lora_gui import gradio_resize_lora_tab +from .extract_lora_gui import gradio_extract_lora_tab +from .convert_lcm_gui import gradio_convert_lcm_tab +from .extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab +from .extract_lora_from_dylora_gui import gradio_extract_dylora_tab +from .merge_lycoris_gui import gradio_merge_lycoris_tab # Deprecated code -from library.dataset_balancing_gui import gradio_dataset_balancing_tab -from library.dreambooth_folder_creation_gui import ( +from .dataset_balancing_gui import gradio_dataset_balancing_tab +from .dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, ) diff --git a/library/class_sample_images.py b/kohya_gui/class_sample_images.py similarity index 98% rename from library/class_sample_images.py rename to kohya_gui/class_sample_images.py index 7ea08d03b..f7ad970d9 100644 --- a/library/class_sample_images.py +++ b/kohya_gui/class_sample_images.py @@ -3,7 +3,7 @@ import gradio as gr from easygui import msgbox -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/library/class_sdxl_parameters.py b/kohya_gui/class_sdxl_parameters.py similarity index 100% rename from library/class_sdxl_parameters.py rename to kohya_gui/class_sdxl_parameters.py diff --git a/library/class_source_model.py b/kohya_gui/class_source_model.py similarity index 100% rename from library/class_source_model.py rename to kohya_gui/class_source_model.py diff --git a/library/common_gui.py b/kohya_gui/common_gui.py similarity index 99% rename from library/common_gui.py rename to kohya_gui/common_gui.py index ce9b031f6..b5698e8a0 100644 --- a/library/common_gui.py +++ b/kohya_gui/common_gui.py @@ -8,7 +8,7 @@ import sys import json -from library.custom_logging import setup_logging +from .custom_logging import setup_logging from datetime import datetime # Set up logging diff --git a/library/convert_lcm_gui.py b/kohya_gui/convert_lcm_gui.py similarity index 98% rename from library/convert_lcm_gui.py rename to kohya_gui/convert_lcm_gui.py index 41b01fc54..ef15cfc3f 100644 --- a/library/convert_lcm_gui.py +++ b/kohya_gui/convert_lcm_gui.py @@ -5,7 +5,7 @@ get_saveasfilename_path, get_file_path, ) -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/library/convert_model_gui.py b/kohya_gui/convert_model_gui.py similarity index 99% rename from library/convert_model_gui.py rename to kohya_gui/convert_model_gui.py index 03f7a93dd..13a639f1d 100644 --- a/library/convert_model_gui.py +++ b/kohya_gui/convert_model_gui.py @@ -5,7 +5,7 @@ import shutil from .common_gui import get_folder_path, get_file_path -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/library/custom_logging.py b/kohya_gui/custom_logging.py similarity index 
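Nearly every renamed file in this patch receives the same mechanical change: from library.X import Y becomes from .X import Y. A change this regular could be applied by a throwaway script along these lines (a hypothetical helper, not part of the patch):

import re
from pathlib import Path

# Rewrite absolute "from library.X import ..." statements to relative form
# in every Python file under the new kohya_gui/ package.
pattern = re.compile(r"^(\s*)from library\.", flags=re.MULTILINE)

for path in Path("kohya_gui").glob("*.py"):
    text = path.read_text(encoding="utf8")
    new_text = pattern.sub(r"\1from .", text)
    if new_text != text:
        path.write_text(new_text, encoding="utf8")
        print(f"rewrote imports in {path}")
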
100% rename from library/custom_logging.py rename to kohya_gui/custom_logging.py diff --git a/library/dataset_balancing_gui.py b/kohya_gui/dataset_balancing_gui.py similarity index 99% rename from library/dataset_balancing_gui.py rename to kohya_gui/dataset_balancing_gui.py index 93697f00d..f2d1c0533 100644 --- a/library/dataset_balancing_gui.py +++ b/kohya_gui/dataset_balancing_gui.py @@ -4,7 +4,7 @@ from easygui import msgbox, boolbox from .common_gui import get_folder_path -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/library/dreambooth_folder_creation_gui.py b/kohya_gui/dreambooth_folder_creation_gui.py similarity index 99% rename from library/dreambooth_folder_creation_gui.py rename to kohya_gui/dreambooth_folder_creation_gui.py index a2b3ec97c..66a243d3b 100644 --- a/library/dreambooth_folder_creation_gui.py +++ b/kohya_gui/dreambooth_folder_creation_gui.py @@ -4,7 +4,7 @@ import shutil import os -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/kohya_gui/dreambooth_gui.py b/kohya_gui/dreambooth_gui.py new file mode 100644 index 000000000..9203d2d0e --- /dev/null +++ b/kohya_gui/dreambooth_gui.py @@ -0,0 +1,911 @@ +# v1: initial release +# v2: add open and save folder icons +# v3: Add new Utilities tab for Dreambooth folder preparation +# v3.1: Adding captionning of images to utilities + +import gradio as gr +import json +import math +import os +import subprocess +import pathlib +from datetime import datetime +from .common_gui import ( + get_file_path, + get_saveasfile_path, + color_aug_changed, + save_inference_file, + run_cmd_advanced_training, + update_my_data, + check_if_model_exist, + output_message, + verify_image_folder_pattern, + SaveConfigFile, + save_to_file, +) +from .class_configuration_file import ConfigurationFile +from .class_source_model import SourceModel +from .class_basic_training import BasicTraining +from .class_advanced_training import AdvancedTraining +from .class_folders import Folders +from .class_command_executor import CommandExecutor +from .class_sdxl_parameters import SDXLParameters +from .tensorboard_gui import ( + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, +) +from .dreambooth_folder_creation_gui import ( + gradio_dreambooth_folder_creation_tab, +) +from .dataset_balancing_gui import gradio_dataset_balancing_tab +from .utilities import utilities_tab +from .class_sample_images import SampleImages, run_cmd_sample +from .custom_logging import setup_logging + + +# Set up logging +log = setup_logging() + +# Setup command executor +executor = CommandExecutor() + + +def save_configuration( + save_as, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + sdxl, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + learning_rate_te, + learning_rate_te1, + learning_rate_te2, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + cache_latents_to_disk, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + full_bf16, + no_token_padding, + stop_text_encoder_training, + min_bucket_reso, + max_bucket_reso, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + 
clip_skip, + vae, + num_processes, + num_machines, + multi_gpu, + gpu_ids, + output_name, + max_token_length, + max_train_epochs, + max_train_steps, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + keep_tokens, + lr_scheduler_num_cycles, + lr_scheduler_power, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + v_pred_like_loss, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + lr_scheduler_args, + noise_offset_type, + noise_offset, + adaptive_noise_scale, + multires_noise_iterations, + multires_noise_discount, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, + weighted_captions, + save_every_n_steps, + save_last_n_steps, + save_last_n_steps_state, + use_wandb, + wandb_api_key, + scale_v_pred_loss_like_noise_pred, + min_timestep, + max_timestep, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + original_file_path = file_path + + save_as_bool = True if save_as.get("label") == "True" else False + + if save_as_bool: + log.info("Save as...") + file_path = get_saveasfile_path(file_path) + else: + log.info("Save...") + if file_path == None or file_path == "": + file_path = get_saveasfile_path(file_path) + + if file_path == None or file_path == "": + return original_file_path # In case a file_path was provided and the user decide to cancel the open action + + # Extract the destination directory from the file path + destination_directory = os.path.dirname(file_path) + + # Create the destination directory if it doesn't exist + if not os.path.exists(destination_directory): + os.makedirs(destination_directory) + + SaveConfigFile( + parameters=parameters, + file_path=file_path, + exclusion=["file_path", "save_as"], + ) + + return file_path + + +def open_configuration( + ask_for_file, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + sdxl, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + learning_rate_te, + learning_rate_te1, + learning_rate_te2, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + cache_latents_to_disk, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + full_bf16, + no_token_padding, + stop_text_encoder_training, + min_bucket_reso, + max_bucket_reso, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + num_processes, + num_machines, + multi_gpu, + gpu_ids, + output_name, + max_token_length, + max_train_epochs, + max_train_steps, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + keep_tokens, + lr_scheduler_num_cycles, + lr_scheduler_power, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + v_pred_like_loss, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + lr_scheduler_args, + noise_offset_type, + noise_offset, + adaptive_noise_scale, + multires_noise_iterations, + multires_noise_discount, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, + weighted_captions, + save_every_n_steps, + 
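save_configuration() above snapshots its entire argument list with locals() and hands it to SaveConfigFile; open_configuration(), whose parameter list continues below, reverses the trip by preferring values from the JSON file and falling back to the current widget values. A minimal sketch of that round-trip, with illustrative names:

import json

def save_settings(file_path, learning_rate, epoch, seed):
    parameters = list(locals().items())  # arguments, in declaration order
    with open(file_path, "w") as f:
        json.dump({k: v for k, v in parameters if k != "file_path"}, f, indent=2)
    return file_path

def load_settings(file_path, learning_rate, epoch, seed):
    parameters = list(locals().items())
    with open(file_path, "r") as f:
        my_data = json.load(f)
    # Prefer the stored value; fall back to the caller's current value.
    return tuple([file_path] + [my_data.get(k, v) for k, v in parameters if k != "file_path"])

path = save_settings("cfg.json", learning_rate="1e-5", epoch=1, seed=1234)
print(load_settings(path, learning_rate="2e-6", epoch=2, seed=0))
# -> ('cfg.json', '1e-5', 1, 1234)
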
save_last_n_steps, + save_last_n_steps_state, + use_wandb, + wandb_api_key, + scale_v_pred_loss_like_noise_pred, + min_timestep, + max_timestep, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get("label") == "True" else False + + original_file_path = file_path + + if ask_for_file: + file_path = get_file_path(file_path) + + if not file_path == "" and not file_path == None: + # load variables from JSON file + with open(file_path, "r") as f: + my_data = json.load(f) + log.info("Loading config...") + # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True + my_data = update_my_data(my_data) + else: + file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action + my_data = {} + + values = [file_path] + for key, value in parameters: + # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found + if not key in ["ask_for_file", "file_path"]: + values.append(my_data.get(key, value)) + return tuple(values) + + +def train_model( + headless, + print_only, + pretrained_model_name_or_path, + v2, + v_parameterization, + sdxl, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + learning_rate_te, + learning_rate_te1, + learning_rate_te2, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + cache_latents_to_disk, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + full_bf16, + no_token_padding, + stop_text_encoder_training_pct, + min_bucket_reso, + max_bucket_reso, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + num_processes, + num_machines, + multi_gpu, + gpu_ids, + output_name, + max_token_length, + max_train_epochs, + max_train_steps, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, # Keep this. 
Yes, it is unused here but required given the common list used + keep_tokens, + lr_scheduler_num_cycles, + lr_scheduler_power, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + v_pred_like_loss, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + lr_scheduler_args, + noise_offset_type, + noise_offset, + adaptive_noise_scale, + multires_noise_iterations, + multires_noise_discount, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, + weighted_captions, + save_every_n_steps, + save_last_n_steps, + save_last_n_steps_state, + use_wandb, + wandb_api_key, + scale_v_pred_loss_like_noise_pred, + min_timestep, + max_timestep, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + print_only_bool = True if print_only.get("label") == "True" else False + log.info(f"Start training Dreambooth...") + + headless_bool = True if headless.get("label") == "True" else False + + if pretrained_model_name_or_path == "": + output_message( + msg="Source model information is missing", headless=headless_bool + ) + return + + if train_data_dir == "": + output_message(msg="Image folder path is missing", headless=headless_bool) + return + + if not os.path.exists(train_data_dir): + output_message(msg="Image folder does not exist", headless=headless_bool) + return + + if not verify_image_folder_pattern(train_data_dir): + return + + if reg_data_dir != "": + if not os.path.exists(reg_data_dir): + output_message( + msg="Regularisation folder does not exist", + headless=headless_bool, + ) + return + + if not verify_image_folder_pattern(reg_data_dir): + return + + if output_dir == "": + output_message(msg="Output folder path is missing", headless=headless_bool) + return + + if check_if_model_exist( + output_name, output_dir, save_model_as, headless=headless_bool + ): + return + + # if sdxl: + # output_message( + # msg='Dreambooth training is not compatible with SDXL models yet..', + # headless=headless_bool, + # ) + # return + + # if optimizer == 'Adafactor' and lr_warmup != '0': + # output_message( + # msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", + # title='Warning', + # headless=headless_bool, + # ) + # lr_warmup = '0' + + # Get a list of all subfolders in train_data_dir, excluding hidden folders + subfolders = [ + f + for f in os.listdir(train_data_dir) + if os.path.isdir(os.path.join(train_data_dir, f)) and not f.startswith(".") + ] + + # Check if subfolders are present. If not let the user know and return + if not subfolders: + log.info(f"No {subfolders} were found in train_data_dir can't train...") + return + + total_steps = 0 + + # Loop through each subfolder and extract the number of repeats + for folder in subfolders: + # Extract the number of repeats from the folder name + try: + repeats = int(folder.split("_")[0]) + except ValueError: + log.info( + f"Subfolder {folder} does not have a proper repeat value, please correct the name or remove it... can't train..." 
+ ) + continue + + # Count the number of images in the folder + num_images = len( + [ + f + for f, lower_f in ( + (file, file.lower()) + for file in os.listdir(os.path.join(train_data_dir, folder)) + ) + if lower_f.endswith((".jpg", ".jpeg", ".png", ".webp")) + ] + ) + + if num_images == 0: + log.info(f"{folder} folder contain no images, skipping...") + else: + # Calculate the total number of steps for this folder + steps = repeats * num_images + total_steps += steps + + # Print the result + log.info(f"Folder {folder} : steps {steps}") + + if total_steps == 0: + log.info(f"No images were found in folder {train_data_dir}... please rectify!") + return + + # Print the result + # log.info(f"{total_steps} total steps") + + if reg_data_dir == "": + reg_factor = 1 + else: + log.info( + f"Regularisation images are used... Will double the number of steps required..." + ) + reg_factor = 2 + + if max_train_steps == "" or max_train_steps == "0": + # calculate max_train_steps + max_train_steps = int( + math.ceil( + float(total_steps) + / int(train_batch_size) + / int(gradient_accumulation_steps) + * int(epoch) + * int(reg_factor) + ) + ) + log.info( + f"max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}" + ) + + # calculate stop encoder training + if int(stop_text_encoder_training_pct) == -1: + stop_text_encoder_training = -1 + elif stop_text_encoder_training_pct == None: + stop_text_encoder_training = 0 + else: + stop_text_encoder_training = math.ceil( + float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) + ) + log.info(f"stop_text_encoder_training = {stop_text_encoder_training}") + + lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) + log.info(f"lr_warmup_steps = {lr_warmup_steps}") + + # run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_db.py"' + run_cmd = "accelerate launch" + + run_cmd += run_cmd_advanced_training( + num_processes=num_processes, + num_machines=num_machines, + multi_gpu=multi_gpu, + gpu_ids=gpu_ids, + num_cpu_threads_per_process=num_cpu_threads_per_process, + ) + + if sdxl: + run_cmd += f' "./sdxl_train.py"' + else: + run_cmd += f' "./train_db.py"' + + run_cmd += run_cmd_advanced_training( + adaptive_noise_scale=adaptive_noise_scale, + additional_parameters=additional_parameters, + bucket_no_upscale=bucket_no_upscale, + bucket_reso_steps=bucket_reso_steps, + cache_latents=cache_latents, + cache_latents_to_disk=cache_latents_to_disk, + caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, + caption_dropout_rate=caption_dropout_rate, + caption_extension=caption_extension, + clip_skip=clip_skip, + color_aug=color_aug, + enable_bucket=enable_bucket, + epoch=epoch, + flip_aug=flip_aug, + full_bf16=full_bf16, + full_fp16=full_fp16, + gradient_accumulation_steps=gradient_accumulation_steps, + gradient_checkpointing=gradient_checkpointing, + keep_tokens=keep_tokens, + learning_rate=learning_rate, + learning_rate_te1=learning_rate_te1 if sdxl else None, + learning_rate_te2=learning_rate_te2 if sdxl else None, + learning_rate_te=learning_rate_te if not sdxl else None, + logging_dir=logging_dir, + lr_scheduler=lr_scheduler, + lr_scheduler_args=lr_scheduler_args, + lr_scheduler_num_cycles=lr_scheduler_num_cycles, + lr_scheduler_power=lr_scheduler_power, + lr_warmup_steps=lr_warmup_steps, + max_bucket_reso=max_bucket_reso, + max_data_loader_n_workers=max_data_loader_n_workers, + max_resolution=max_resolution, + 
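For reference, the Dreambooth step arithmetic implemented above condenses to the following (the folder name and counts are hypothetical):

import math

# Folder convention: "<repeats>_<name>", e.g. 40 repeats of "mysubject"
folder = "40_mysubject"
repeats = int(folder.split("_")[0])     # 40
num_images = 25                         # images found in that folder
total_steps = repeats * num_images      # 1000

train_batch_size = 2
gradient_accumulation_steps = 1
epoch = 1
reg_factor = 2                          # 2 when regularisation images are used, else 1

max_train_steps = int(
    math.ceil(total_steps / train_batch_size / gradient_accumulation_steps * epoch * reg_factor)
)                                       # 1000

# Both ancillary step counts are percentages of max_train_steps:
stop_text_encoder_training_pct = 50
stop_text_encoder_training = math.ceil(max_train_steps / 100 * stop_text_encoder_training_pct)  # 500
lr_warmup = 10
lr_warmup_steps = round(lr_warmup * max_train_steps / 100)  # 100
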
max_timestep=max_timestep, + max_token_length=max_token_length, + max_train_epochs=max_train_epochs, + max_train_steps=max_train_steps, + mem_eff_attn=mem_eff_attn, + min_bucket_reso=min_bucket_reso, + min_snr_gamma=min_snr_gamma, + min_timestep=min_timestep, + mixed_precision=mixed_precision, + multires_noise_discount=multires_noise_discount, + multires_noise_iterations=multires_noise_iterations, + no_token_padding=no_token_padding, + noise_offset=noise_offset, + noise_offset_type=noise_offset_type, + optimizer=optimizer, + optimizer_args=optimizer_args, + output_dir=output_dir, + output_name=output_name, + persistent_data_loader_workers=persistent_data_loader_workers, + pretrained_model_name_or_path=pretrained_model_name_or_path, + prior_loss_weight=prior_loss_weight, + random_crop=random_crop, + reg_data_dir=reg_data_dir, + resume=resume, + save_every_n_epochs=save_every_n_epochs, + save_every_n_steps=save_every_n_steps, + save_last_n_steps=save_last_n_steps, + save_last_n_steps_state=save_last_n_steps_state, + save_model_as=save_model_as, + save_precision=save_precision, + save_state=save_state, + scale_v_pred_loss_like_noise_pred=scale_v_pred_loss_like_noise_pred, + seed=seed, + shuffle_caption=shuffle_caption, + stop_text_encoder_training=stop_text_encoder_training, + train_batch_size=train_batch_size, + train_data_dir=train_data_dir, + use_wandb=use_wandb, + v2=v2, + v_parameterization=v_parameterization, + v_pred_like_loss=v_pred_like_loss, + vae=vae, + vae_batch_size=vae_batch_size, + wandb_api_key=wandb_api_key, + weighted_captions=weighted_captions, + xformers=xformers, + ) + + run_cmd += run_cmd_sample( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + output_dir, + ) + + if print_only_bool: + log.warning( + "Here is the trainer command as a reference. It will not be executed:\n" + ) + print(run_cmd) + + save_to_file(run_cmd) + else: + # Saving config file for model + current_datetime = datetime.now() + formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") + file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json") + + log.info(f"Saving training config to {file_path}...") + + SaveConfigFile( + parameters=parameters, + file_path=file_path, + exclusion=["file_path", "save_as", "headless", "print_only"], + ) + + log.info(run_cmd) + + # Run the command + + executor.execute_command(run_cmd=run_cmd) + + # check if output_dir/last is a folder... 
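The trainer invocation is assembled incrementally: an accelerate launch prefix, then the training script, then one CLI flag per argument via run_cmd_advanced_training. The real helper lives in common_gui.py and handles many special cases; the general flag-appending pattern it follows looks roughly like this:

def build_flags(**kwargs) -> str:
    # One ' --name="value"' per argument that is set; booleans become bare
    # flags, while None and empty strings are skipped entirely.
    cmd = ""
    for name, value in kwargs.items():
        if value is None or value == "":
            continue
        if isinstance(value, bool):
            if value:
                cmd += f" --{name}"
        else:
            cmd += f' --{name}="{value}"'
    return cmd

run_cmd = "accelerate launch"
run_cmd += ' "./train_db.py"'
run_cmd += build_flags(pretrained_model_name_or_path="path/to/model",
                       train_batch_size=2, xformers=True, vae=None)
print(run_cmd)
# accelerate launch "./train_db.py" --pretrained_model_name_or_path="path/to/model" --train_batch_size="2" --xformers
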
therefore it is a diffuser model + last_dir = pathlib.Path(f"{output_dir}/{output_name}") + + if not last_dir.is_dir(): + # Copy inference model for v2 if required + save_inference_file(output_dir, v2, v_parameterization, output_name) + + +def dreambooth_tab( + # train_data_dir=gr.Textbox(), + # reg_data_dir=gr.Textbox(), + # output_dir=gr.Textbox(), + # logging_dir=gr.Textbox(), + headless=False, +): + dummy_db_true = gr.Label(value=True, visible=False) + dummy_db_false = gr.Label(value=False, visible=False) + dummy_headless = gr.Label(value=headless, visible=False) + + with gr.Tab("Training"): + gr.Markdown("Train a custom model using kohya dreambooth python code...") + + # Setup Configuration Files Gradio + config = ConfigurationFile(headless) + + source_model = SourceModel(headless=headless) + + with gr.Tab("Folders"): + folders = Folders(headless=headless) + with gr.Tab("Parameters"): + with gr.Tab("Basic", elem_id="basic_tab"): + basic_training = BasicTraining( + learning_rate_value="1e-5", + lr_scheduler_value="cosine", + lr_warmup_value="10", + dreambooth=True, + sdxl_checkbox=source_model.sdxl_checkbox, + ) + + # # Add SDXL Parameters + # sdxl_params = SDXLParameters(source_model.sdxl_checkbox, show_sdxl_cache_text_encoder_outputs=False) + + with gr.Tab("Advanced", elem_id="advanced_tab"): + advanced_training = AdvancedTraining(headless=headless) + advanced_training.color_aug.change( + color_aug_changed, + inputs=[advanced_training.color_aug], + outputs=[basic_training.cache_latents], + ) + + with gr.Tab("Samples", elem_id="samples_tab"): + sample = SampleImages() + + with gr.Tab("Dataset Preparation"): + gr.Markdown( + "This section provides Dreambooth tools to help set up your dataset..." + ) + gradio_dreambooth_folder_creation_tab( + train_data_dir_input=folders.train_data_dir, + reg_data_dir_input=folders.reg_data_dir, + output_dir_input=folders.output_dir, + logging_dir_input=folders.logging_dir, + headless=headless, + ) + gradio_dataset_balancing_tab(headless=headless) + + with gr.Row(): + button_run = gr.Button("Start training", variant="primary") + + button_stop_training = gr.Button("Stop training") + + button_print = gr.Button("Print training command") + + # Setup gradio tensorboard buttons + ( + button_start_tensorboard, + button_stop_tensorboard, + ) = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=[dummy_headless, folders.logging_dir], + show_progress=False, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + show_progress=False, + ) + + settings_list = [ + source_model.pretrained_model_name_or_path, + source_model.v2, + source_model.v_parameterization, + source_model.sdxl_checkbox, + folders.logging_dir, + folders.train_data_dir, + folders.reg_data_dir, + folders.output_dir, + basic_training.max_resolution, + basic_training.learning_rate, + basic_training.learning_rate_te, + basic_training.learning_rate_te1, + basic_training.learning_rate_te2, + basic_training.lr_scheduler, + basic_training.lr_warmup, + basic_training.train_batch_size, + basic_training.epoch, + basic_training.save_every_n_epochs, + basic_training.mixed_precision, + basic_training.save_precision, + basic_training.seed, + basic_training.num_cpu_threads_per_process, + basic_training.cache_latents, + basic_training.cache_latents_to_disk, + basic_training.caption_extension, + basic_training.enable_bucket, + advanced_training.gradient_checkpointing, + advanced_training.full_fp16, + advanced_training.full_bf16, + advanced_training.no_token_padding, +
basic_training.stop_text_encoder_training, + basic_training.min_bucket_reso, + basic_training.max_bucket_reso, + advanced_training.xformers, + source_model.save_model_as, + advanced_training.shuffle_caption, + advanced_training.save_state, + advanced_training.resume, + advanced_training.prior_loss_weight, + advanced_training.color_aug, + advanced_training.flip_aug, + advanced_training.clip_skip, + advanced_training.vae, + advanced_training.num_processes, + advanced_training.num_machines, + advanced_training.multi_gpu, + advanced_training.gpu_ids, + folders.output_name, + advanced_training.max_token_length, + basic_training.max_train_epochs, + basic_training.max_train_steps, + advanced_training.max_data_loader_n_workers, + advanced_training.mem_eff_attn, + advanced_training.gradient_accumulation_steps, + source_model.model_list, + advanced_training.keep_tokens, + basic_training.lr_scheduler_num_cycles, + basic_training.lr_scheduler_power, + advanced_training.persistent_data_loader_workers, + advanced_training.bucket_no_upscale, + advanced_training.random_crop, + advanced_training.bucket_reso_steps, + advanced_training.v_pred_like_loss, + advanced_training.caption_dropout_every_n_epochs, + advanced_training.caption_dropout_rate, + basic_training.optimizer, + basic_training.optimizer_args, + basic_training.lr_scheduler_args, + advanced_training.noise_offset_type, + advanced_training.noise_offset, + advanced_training.adaptive_noise_scale, + advanced_training.multires_noise_iterations, + advanced_training.multires_noise_discount, + sample.sample_every_n_steps, + sample.sample_every_n_epochs, + sample.sample_sampler, + sample.sample_prompts, + advanced_training.additional_parameters, + advanced_training.vae_batch_size, + advanced_training.min_snr_gamma, + advanced_training.weighted_captions, + advanced_training.save_every_n_steps, + advanced_training.save_last_n_steps, + advanced_training.save_last_n_steps_state, + advanced_training.use_wandb, + advanced_training.wandb_api_key, + advanced_training.scale_v_pred_loss_like_noise_pred, + advanced_training.min_timestep, + advanced_training.max_timestep, + ] + + config.button_open_config.click( + open_configuration, + inputs=[dummy_db_true, config.config_file_name] + settings_list, + outputs=[config.config_file_name] + settings_list, + show_progress=False, + ) + + config.button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config.config_file_name] + settings_list, + outputs=[config.config_file_name] + settings_list, + show_progress=False, + ) + + config.button_save_config.click( + save_configuration, + inputs=[dummy_db_false, config.config_file_name] + settings_list, + outputs=[config.config_file_name], + show_progress=False, + ) + + config.button_save_as_config.click( + save_configuration, + inputs=[dummy_db_true, config.config_file_name] + settings_list, + outputs=[config.config_file_name], + show_progress=False, + ) + + button_run.click( + train_model, + inputs=[dummy_headless] + [dummy_db_false] + settings_list, + show_progress=False, + ) + + button_stop_training.click(executor.kill_command) + + button_print.click( + train_model, + inputs=[dummy_headless] + [dummy_db_true] + settings_list, + show_progress=False, + ) + + return ( + folders.train_data_dir, + folders.reg_data_dir, + folders.output_dir, + folders.logging_dir, + ) diff --git a/library/extract_lora_from_dylora_gui.py b/kohya_gui/extract_lora_from_dylora_gui.py similarity index 98% rename from library/extract_lora_from_dylora_gui.py rename to 
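Throughout dreambooth_tab above, constant booleans are smuggled into event handlers via hidden gr.Label components (dummy_db_true, dummy_db_false, dummy_headless); the handler then recovers the flag with .get("label") == "True", which is how train_model distinguishes a real run from a print-only one. A minimal sketch of the trick, assuming the same Gradio 3.x Label payload these handlers rely on:

import gradio as gr

def handler(print_only, text):
    # Hidden Labels arrive as dicts; recover the boolean from the label field.
    print_only_bool = print_only.get("label") == "True"
    return f"print_only={print_only_bool}, text={text!r}"

with gr.Blocks() as demo:
    dummy_true = gr.Label(value=True, visible=False)
    dummy_false = gr.Label(value=False, visible=False)
    text = gr.Textbox()
    out = gr.Textbox()
    gr.Button("Run").click(handler, inputs=[dummy_false, text], outputs=out)
    gr.Button("Print command").click(handler, inputs=[dummy_true, text], outputs=out)

# demo.launch()
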
kohya_gui/extract_lora_from_dylora_gui.py index 5e84fb58b..d2851a218 100644 --- a/library/extract_lora_from_dylora_gui.py +++ b/kohya_gui/extract_lora_from_dylora_gui.py @@ -7,7 +7,7 @@ get_file_path, ) -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/library/extract_lora_gui.py b/kohya_gui/extract_lora_gui.py similarity index 99% rename from library/extract_lora_gui.py rename to kohya_gui/extract_lora_gui.py index 7dde72368..c81945473 100644 --- a/library/extract_lora_gui.py +++ b/kohya_gui/extract_lora_gui.py @@ -8,7 +8,7 @@ is_file_writable, ) -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/library/extract_lycoris_locon_gui.py b/kohya_gui/extract_lycoris_locon_gui.py similarity index 99% rename from library/extract_lycoris_locon_gui.py rename to kohya_gui/extract_lycoris_locon_gui.py index a655371ae..7e3ca4a5b 100644 --- a/library/extract_lycoris_locon_gui.py +++ b/kohya_gui/extract_lycoris_locon_gui.py @@ -8,7 +8,7 @@ get_file_path, ) -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/kohya_gui/finetune_gui.py b/kohya_gui/finetune_gui.py new file mode 100644 index 000000000..a2c23040e --- /dev/null +++ b/kohya_gui/finetune_gui.py @@ -0,0 +1,1084 @@ +import gradio as gr +import json +import math +import os +import subprocess +import pathlib +from datetime import datetime +from .common_gui import ( + get_folder_path, + get_file_path, + get_saveasfile_path, + save_inference_file, + run_cmd_advanced_training, + color_aug_changed, + update_my_data, + check_if_model_exist, + SaveConfigFile, + save_to_file, +) +from .class_configuration_file import ConfigurationFile +from .class_source_model import SourceModel +from .class_basic_training import BasicTraining +from .class_advanced_training import AdvancedTraining +from .class_sdxl_parameters import SDXLParameters +from .class_command_executor import CommandExecutor +from .tensorboard_gui import ( + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, +) +from .class_sample_images import SampleImages, run_cmd_sample + +from .custom_logging import setup_logging + +# Set up logging +log = setup_logging() + +# Setup command executor +executor = CommandExecutor() + +# from easygui import msgbox + +folder_symbol = "\U0001f4c2" # 📂 +refresh_symbol = "\U0001f504" # 🔄 +save_style_symbol = "\U0001f4be" # 💾 +document_symbol = "\U0001F4C4" # 📄 + +PYTHON = "python3" if os.name == "posix" else "./venv/Scripts/python.exe" + + +def save_configuration( + save_as, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + sdxl_checkbox, + train_dir, + image_folder, + output_dir, + logging_dir, + max_resolution, + min_bucket_reso, + max_bucket_reso, + batch_size, + flip_aug, + caption_metadata_filename, + latent_metadata_filename, + full_path, + learning_rate, + lr_scheduler, + lr_warmup, + dataset_repeats, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + learning_rate_te, + learning_rate_te1, + learning_rate_te2, + train_text_encoder, + full_bf16, + create_caption, + create_buckets, + save_model_as, + caption_extension, + # use_8bit_adam, + xformers, + clip_skip, + num_processes, + num_machines, + multi_gpu, + gpu_ids, + save_state, + resume, + gradient_checkpointing, + 
gradient_accumulation_steps, + block_lr, + mem_eff_attn, + shuffle_caption, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + full_fp16, + color_aug, + model_list, + cache_latents, + cache_latents_to_disk, + use_latent_files, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + v_pred_like_loss, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + lr_scheduler_args, + noise_offset_type, + noise_offset, + adaptive_noise_scale, + multires_noise_iterations, + multires_noise_discount, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, + weighted_captions, + save_every_n_steps, + save_last_n_steps, + save_last_n_steps_state, + use_wandb, + wandb_api_key, + scale_v_pred_loss_like_noise_pred, + sdxl_cache_text_encoder_outputs, + sdxl_no_half_vae, + min_timestep, + max_timestep, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + original_file_path = file_path + + save_as_bool = True if save_as.get("label") == "True" else False + + if save_as_bool: + log.info("Save as...") + file_path = get_saveasfile_path(file_path) + else: + log.info("Save...") + if file_path == None or file_path == "": + file_path = get_saveasfile_path(file_path) + + # log.info(file_path) + + if file_path == None or file_path == "": + return original_file_path # In case a file_path was provided and the user decide to cancel the open action + + # Extract the destination directory from the file path + destination_directory = os.path.dirname(file_path) + + # Create the destination directory if it doesn't exist + if not os.path.exists(destination_directory): + os.makedirs(destination_directory) + + SaveConfigFile( + parameters=parameters, + file_path=file_path, + exclusion=["file_path", "save_as"], + ) + + return file_path + + +def open_configuration( + ask_for_file, + apply_preset, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + sdxl_checkbox, + train_dir, + image_folder, + output_dir, + logging_dir, + max_resolution, + min_bucket_reso, + max_bucket_reso, + batch_size, + flip_aug, + caption_metadata_filename, + latent_metadata_filename, + full_path, + learning_rate, + lr_scheduler, + lr_warmup, + dataset_repeats, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + learning_rate_te, + learning_rate_te1, + learning_rate_te2, + train_text_encoder, + full_bf16, + create_caption, + create_buckets, + save_model_as, + caption_extension, + # use_8bit_adam, + xformers, + clip_skip, + num_processes, + num_machines, + multi_gpu, + gpu_ids, + save_state, + resume, + gradient_checkpointing, + gradient_accumulation_steps, + block_lr, + mem_eff_attn, + shuffle_caption, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + full_fp16, + color_aug, + model_list, + cache_latents, + cache_latents_to_disk, + use_latent_files, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + v_pred_like_loss, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + lr_scheduler_args, + noise_offset_type, + noise_offset, + adaptive_noise_scale, + multires_noise_iterations, + multires_noise_discount, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + 
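The finetune variant of open_configuration, continued below, takes a second hidden flag, apply_preset: when set, the config path is redirected to a bundled preset under ./presets/finetune/, and otherwise the training_preset entry is blanked so the dropdown is cleared when the values are written back to the UI. Condensed (the preset name is hypothetical):

training_preset = "SDXL - 8GB"   # hypothetical preset name
apply_preset = True

parameters = [("training_preset", training_preset), ("learning_rate", "1e-5")]

if apply_preset:
    file_path = f"./presets/finetune/{training_preset}.json"
else:
    # Blank the training_preset field in the snapshot taken from locals().
    idx = parameters.index(("training_preset", training_preset))
    parameters[idx] = ("training_preset", "")
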
additional_parameters, + vae_batch_size, + min_snr_gamma, + weighted_captions, + save_every_n_steps, + save_last_n_steps, + save_last_n_steps_state, + use_wandb, + wandb_api_key, + scale_v_pred_loss_like_noise_pred, + sdxl_cache_text_encoder_outputs, + sdxl_no_half_vae, + min_timestep, + max_timestep, + training_preset, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get("label") == "True" else False + apply_preset = True if apply_preset.get("label") == "True" else False + + # Check if we are "applying" a preset or a config + if apply_preset: + log.info(f"Applying preset {training_preset}...") + file_path = f"./presets/finetune/{training_preset}.json" + else: + # If not applying a preset, set the `training_preset` field to an empty string + # Find the index of the `training_preset` parameter using the `index()` method + training_preset_index = parameters.index(("training_preset", training_preset)) + + # Update the value of `training_preset` by directly assigning an empty string value + parameters[training_preset_index] = ("training_preset", "") + + original_file_path = file_path + + if ask_for_file: + file_path = get_file_path(file_path) + + if not file_path == "" and not file_path == None: + # load variables from JSON file + with open(file_path, "r") as f: + my_data = json.load(f) + log.info("Loading config...") + # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True + my_data = update_my_data(my_data) + else: + file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action + my_data = {} + + values = [file_path] + for key, value in parameters: + json_value = my_data.get(key) + # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found + if not key in ["ask_for_file", "apply_preset", "file_path"]: + values.append(json_value if json_value is not None else value) + return tuple(values) + + +def train_model( + headless, + print_only, + pretrained_model_name_or_path, + v2, + v_parameterization, + sdxl_checkbox, + train_dir, + image_folder, + output_dir, + logging_dir, + max_resolution, + min_bucket_reso, + max_bucket_reso, + batch_size, + flip_aug, + caption_metadata_filename, + latent_metadata_filename, + full_path, + learning_rate, + lr_scheduler, + lr_warmup, + dataset_repeats, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + learning_rate_te, + learning_rate_te1, + learning_rate_te2, + train_text_encoder, + full_bf16, + generate_caption_database, + generate_image_buckets, + save_model_as, + caption_extension, + # use_8bit_adam, + xformers, + clip_skip, + num_processes, + num_machines, + multi_gpu, + gpu_ids, + save_state, + resume, + gradient_checkpointing, + gradient_accumulation_steps, + block_lr, + mem_eff_attn, + shuffle_caption, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + full_fp16, + color_aug, + model_list, # Keep this. 
Yes, it is unused here but required given the common list used + cache_latents, + cache_latents_to_disk, + use_latent_files, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + v_pred_like_loss, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + lr_scheduler_args, + noise_offset_type, + noise_offset, + adaptive_noise_scale, + multires_noise_iterations, + multires_noise_discount, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, + weighted_captions, + save_every_n_steps, + save_last_n_steps, + save_last_n_steps_state, + use_wandb, + wandb_api_key, + scale_v_pred_loss_like_noise_pred, + sdxl_cache_text_encoder_outputs, + sdxl_no_half_vae, + min_timestep, + max_timestep, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + print_only_bool = True if print_only.get("label") == "True" else False + log.info(f"Start Finetuning...") + + headless_bool = True if headless.get("label") == "True" else False + + if check_if_model_exist(output_name, output_dir, save_model_as, headless_bool): + return + + # if float(noise_offset) > 0 and ( + # multires_noise_iterations > 0 or multires_noise_discount > 0 + # ): + # output_message( + # msg="noise offset and multires_noise can't be set at the same time. Only use one or the other.", + # title='Error', + # headless=headless_bool, + # ) + # return + + # if optimizer == 'Adafactor' and lr_warmup != '0': + # output_message( + # msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", + # title='Warning', + # headless=headless_bool, + # ) + # lr_warmup = '0' + + # create caption json file + if generate_caption_database: + if not os.path.exists(train_dir): + os.mkdir(train_dir) + + run_cmd = f"{PYTHON} finetune/merge_captions_to_metadata.py" + if caption_extension == "": + run_cmd += f' --caption_extension=".caption"' + else: + run_cmd += f" --caption_extension={caption_extension}" + run_cmd += f' "{image_folder}"' + run_cmd += f' "{train_dir}/{caption_metadata_filename}"' + if full_path: + run_cmd += f" --full_path" + + log.info(run_cmd) + + if not print_only_bool: + # Run the command + if os.name == "posix": + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + # create images buckets + if generate_image_buckets: + run_cmd = f"{PYTHON} finetune/prepare_buckets_latents.py" + run_cmd += f' "{image_folder}"' + run_cmd += f' "{train_dir}/{caption_metadata_filename}"' + run_cmd += f' "{train_dir}/{latent_metadata_filename}"' + run_cmd += f' "{pretrained_model_name_or_path}"' + run_cmd += f" --batch_size={batch_size}" + run_cmd += f" --max_resolution={max_resolution}" + run_cmd += f" --min_bucket_reso={min_bucket_reso}" + run_cmd += f" --max_bucket_reso={max_bucket_reso}" + run_cmd += f" --mixed_precision={mixed_precision}" + # if flip_aug: + # run_cmd += f' --flip_aug' + if full_path: + run_cmd += f" --full_path" + if sdxl_checkbox and sdxl_no_half_vae: + log.info("Using mixed_precision = no because no half vae is selected...") + run_cmd += f' --mixed_precision="no"' + + log.info(run_cmd) + + if not print_only_bool: + # Run the command + if os.name == "posix": + os.system(run_cmd) + else: + subprocess.run(run_cmd) + + image_num = len( + [ + f + for f, lower_f in ( + (file, file.lower()) for file in os.listdir(image_folder) + ) + if lower_f.endswith((".jpg", ".jpeg", ".png", ".webp")) + ] + ) + 
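Before training, finetuning prepares its metadata in two passes: per-image captions are merged into meta_cap.json, then bucket assignments and cached latents are written to meta_lat.json. The two commands assembled above reduce to roughly the following (the folders and model path are placeholders):

PYTHON = "python3"            # "./venv/Scripts/python.exe" on Windows
train_dir = "./train_cfg"     # hypothetical config folder
image_folder = "./images"     # hypothetical image folder

# Pass 1: merge per-image .caption files into a caption metadata JSON.
caption_cmd = (
    f"{PYTHON} finetune/merge_captions_to_metadata.py"
    f' --caption_extension=".caption" "{image_folder}"'
    f' "{train_dir}/meta_cap.json" --full_path'
)

# Pass 2: bucket the images and precompute latents into a second JSON.
buckets_cmd = (
    f"{PYTHON} finetune/prepare_buckets_latents.py"
    f' "{image_folder}" "{train_dir}/meta_cap.json" "{train_dir}/meta_lat.json"'
    f' "path/to/model" --batch_size=1 --max_resolution=512,512'
    f" --min_bucket_reso=256 --max_bucket_reso=1024 --mixed_precision=no --full_path"
)
print(caption_cmd)
print(buckets_cmd)
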
log.info(f"image_num = {image_num}") + + repeats = int(image_num) * int(dataset_repeats) + log.info(f"repeats = {str(repeats)}") + + # calculate max_train_steps + max_train_steps = int( + math.ceil( + float(repeats) + / int(train_batch_size) + / int(gradient_accumulation_steps) + * int(epoch) + ) + ) + + # Divide by two because flip augmentation create two copied of the source images + if flip_aug: + max_train_steps = int(math.ceil(float(max_train_steps) / 2)) + + log.info(f"max_train_steps = {max_train_steps}") + + lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) + log.info(f"lr_warmup_steps = {lr_warmup_steps}") + + run_cmd = "accelerate launch" + + run_cmd += run_cmd_advanced_training( + num_processes=num_processes, + num_machines=num_machines, + multi_gpu=multi_gpu, + gpu_ids=gpu_ids, + num_cpu_threads_per_process=num_cpu_threads_per_process, + ) + + if sdxl_checkbox: + run_cmd += f' "./sdxl_train.py"' + else: + run_cmd += f' "./fine_tune.py"' + + in_json = ( + f"{train_dir}/{latent_metadata_filename}" + if use_latent_files == "Yes" + else f"{train_dir}/{caption_metadata_filename}" + ) + cache_text_encoder_outputs = sdxl_checkbox and sdxl_cache_text_encoder_outputs + no_half_vae = sdxl_checkbox and sdxl_no_half_vae + + run_cmd += run_cmd_advanced_training( + adaptive_noise_scale=adaptive_noise_scale, + additional_parameters=additional_parameters, + block_lr=block_lr, + bucket_no_upscale=bucket_no_upscale, + bucket_reso_steps=bucket_reso_steps, + cache_latents=cache_latents, + cache_latents_to_disk=cache_latents_to_disk, + cache_text_encoder_outputs=cache_text_encoder_outputs + if sdxl_checkbox + else None, + caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, + caption_dropout_rate=caption_dropout_rate, + caption_extension=caption_extension, + clip_skip=clip_skip, + color_aug=color_aug, + dataset_repeats=dataset_repeats, + enable_bucket=True, + flip_aug=flip_aug, + full_bf16=full_bf16, + full_fp16=full_fp16, + gradient_accumulation_steps=gradient_accumulation_steps, + gradient_checkpointing=gradient_checkpointing, + in_json=in_json, + keep_tokens=keep_tokens, + learning_rate=learning_rate, + learning_rate_te1=learning_rate_te1 if sdxl_checkbox else None, + learning_rate_te2=learning_rate_te2 if sdxl_checkbox else None, + learning_rate_te=learning_rate_te if not sdxl_checkbox else None, + logging_dir=logging_dir, + lr_scheduler=lr_scheduler, + lr_scheduler_args=lr_scheduler_args, + lr_warmup_steps=lr_warmup_steps, + max_bucket_reso=max_bucket_reso, + max_data_loader_n_workers=max_data_loader_n_workers, + max_resolution=max_resolution, + max_timestep=max_timestep, + max_token_length=max_token_length, + max_train_epochs=max_train_epochs, + max_train_steps=max_train_steps, + mem_eff_attn=mem_eff_attn, + min_bucket_reso=min_bucket_reso, + min_snr_gamma=min_snr_gamma, + min_timestep=min_timestep, + mixed_precision=mixed_precision, + multires_noise_discount=multires_noise_discount, + multires_noise_iterations=multires_noise_iterations, + no_half_vae=no_half_vae if sdxl_checkbox else None, + noise_offset=noise_offset, + noise_offset_type=noise_offset_type, + optimizer=optimizer, + optimizer_args=optimizer_args, + output_dir=output_dir, + output_name=output_name, + persistent_data_loader_workers=persistent_data_loader_workers, + pretrained_model_name_or_path=pretrained_model_name_or_path, + random_crop=random_crop, + resume=resume, + save_every_n_epochs=save_every_n_epochs, + save_every_n_steps=save_every_n_steps, + save_last_n_steps=save_last_n_steps, + 
save_last_n_steps_state=save_last_n_steps_state, + save_model_as=save_model_as, + save_precision=save_precision, + save_state=save_state, + scale_v_pred_loss_like_noise_pred=scale_v_pred_loss_like_noise_pred, + seed=seed, + shuffle_caption=shuffle_caption, + train_batch_size=train_batch_size, + train_data_dir=image_folder, + train_text_encoder=train_text_encoder, + use_wandb=use_wandb, + v2=v2, + v_parameterization=v_parameterization, + v_pred_like_loss=v_pred_like_loss, + vae_batch_size=vae_batch_size, + wandb_api_key=wandb_api_key, + weighted_captions=weighted_captions, + xformers=xformers, + ) + + run_cmd += run_cmd_sample( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + output_dir, + ) + + if print_only_bool: + log.warning( + "Here is the trainer command as a reference. It will not be executed:\n" + ) + print(run_cmd) + + save_to_file(run_cmd) + else: + # Saving config file for model + current_datetime = datetime.now() + formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") + file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json") + + log.info(f"Saving training config to {file_path}...") + + SaveConfigFile( + parameters=parameters, + file_path=file_path, + exclusion=["file_path", "save_as", "headless", "print_only"], + ) + + log.info(run_cmd) + + # Run the command + executor.execute_command(run_cmd=run_cmd) + + # check if output_dir/last is a folder... therefore it is a diffuser model + last_dir = pathlib.Path(f"{output_dir}/{output_name}") + + if not last_dir.is_dir(): + # Copy inference model for v2 if required + save_inference_file(output_dir, v2, v_parameterization, output_name) + + +def remove_doublequote(file_path): + if file_path != None: + file_path = file_path.replace('"', "") + + return file_path + + +def finetune_tab(headless=False): + dummy_db_true = gr.Label(value=True, visible=False) + dummy_db_false = gr.Label(value=False, visible=False) + dummy_headless = gr.Label(value=headless, visible=False) + with gr.Tab("Training"): + gr.Markdown("Train a custom model using kohya finetune python code...") + + # Setup Configuration Files Gradio + config = ConfigurationFile(headless) + + source_model = SourceModel(headless=headless) + + with gr.Tab("Folders"): + with gr.Row(): + train_dir = gr.Textbox( + label="Training config folder", + placeholder="folder where the training configuration files will be saved", + ) + train_dir_folder = gr.Button( + folder_symbol, + elem_id="open_folder_small", + visible=(not headless), + ) + train_dir_folder.click( + get_folder_path, + outputs=train_dir, + show_progress=False, + ) + + image_folder = gr.Textbox( + label="Training Image folder", + placeholder="folder where the training images are located", + ) + image_folder_input_folder = gr.Button( + folder_symbol, + elem_id="open_folder_small", + visible=(not headless), + ) + image_folder_input_folder.click( + get_folder_path, + outputs=image_folder, + show_progress=False, + ) + with gr.Row(): + output_dir = gr.Textbox( + label="Model output folder", + placeholder="folder where the model will be saved", + ) + output_dir_input_folder = gr.Button( + folder_symbol, + elem_id="open_folder_small", + visible=(not headless), + ) + output_dir_input_folder.click( + get_folder_path, + outputs=output_dir, + show_progress=False, + ) + + logging_dir = gr.Textbox( + label="Logging folder", + placeholder="Optional: enable logging and output TensorBoard log to this folder", + ) + logging_dir_input_folder = gr.Button( + folder_symbol, + 
elem_id="open_folder_small", + visible=(not headless), + ) + logging_dir_input_folder.click( + get_folder_path, + outputs=logging_dir, + show_progress=False, + ) + with gr.Row(): + output_name = gr.Textbox( + label="Model output name", + placeholder="Name of the model to output", + value="last", + interactive=True, + ) + train_dir.change( + remove_doublequote, + inputs=[train_dir], + outputs=[train_dir], + ) + image_folder.change( + remove_doublequote, + inputs=[image_folder], + outputs=[image_folder], + ) + output_dir.change( + remove_doublequote, + inputs=[output_dir], + outputs=[output_dir], + ) + with gr.Tab("Dataset preparation"): + with gr.Row(): + max_resolution = gr.Textbox( + label="Resolution (width,height)", value="512,512" + ) + min_bucket_reso = gr.Textbox(label="Min bucket resolution", value="256") + max_bucket_reso = gr.Textbox( + label="Max bucket resolution", value="1024" + ) + batch_size = gr.Textbox(label="Batch size", value="1") + with gr.Row(): + create_caption = gr.Checkbox( + label="Generate caption metadata", value=True + ) + create_buckets = gr.Checkbox( + label="Generate image buckets metadata", value=True + ) + use_latent_files = gr.Dropdown( + label="Use latent files", + choices=[ + "No", + "Yes", + ], + value="Yes", + ) + with gr.Accordion("Advanced parameters", open=False): + with gr.Row(): + caption_metadata_filename = gr.Textbox( + label="Caption metadata filename", + value="meta_cap.json", + ) + latent_metadata_filename = gr.Textbox( + label="Latent metadata filename", value="meta_lat.json" + ) + with gr.Row(): + full_path = gr.Checkbox(label="Use full path", value=True) + weighted_captions = gr.Checkbox( + label="Weighted captions", value=False + ) + with gr.Tab("Parameters"): + + def list_presets(path): + json_files = [] + + for file in os.listdir(path): + if file.endswith(".json"): + json_files.append(os.path.splitext(file)[0]) + + user_presets_path = os.path.join(path, "user_presets") + if os.path.isdir(user_presets_path): + for file in os.listdir(user_presets_path): + if file.endswith(".json"): + preset_name = os.path.splitext(file)[0] + json_files.append(os.path.join("user_presets", preset_name)) + + return json_files + + training_preset = gr.Dropdown( + label="Presets", + choices=list_presets("./presets/finetune"), + elem_id="myDropdown", + ) + + with gr.Tab("Basic", elem_id="basic_tab"): + basic_training = BasicTraining( + learning_rate_value="1e-5", + finetuning=True, + sdxl_checkbox=source_model.sdxl_checkbox, + ) + + # Add SDXL Parameters + sdxl_params = SDXLParameters(source_model.sdxl_checkbox) + + with gr.Row(): + dataset_repeats = gr.Textbox(label="Dataset repeats", value=40) + train_text_encoder = gr.Checkbox( + label="Train text encoder", value=True + ) + + with gr.Tab("Advanced", elem_id="advanced_tab"): + with gr.Row(): + gradient_accumulation_steps = gr.Number( + label="Gradient accumulate steps", value="1" + ) + block_lr = gr.Textbox( + label="Block LR", + placeholder="(Optional)", + info="Specify the different learning rates for each U-Net block. Specify 23 values separated by commas like 1e-3,1e-3 ... 
1e-3", + ) + advanced_training = AdvancedTraining(headless=headless, finetuning=True) + advanced_training.color_aug.change( + color_aug_changed, + inputs=[advanced_training.color_aug], + outputs=[ + basic_training.cache_latents + ], # Not applicable to fine_tune.py + ) + + with gr.Tab("Samples", elem_id="samples_tab"): + sample = SampleImages() + + with gr.Row(): + button_run = gr.Button("Start training", variant="primary") + + button_stop_training = gr.Button("Stop training") + + button_print = gr.Button("Print training command") + + # Setup gradio tensorboard buttons + ( + button_start_tensorboard, + button_stop_tensorboard, + ) = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=[dummy_headless, logging_dir], + ) + + button_stop_tensorboard.click( + stop_tensorboard, + show_progress=False, + ) + + settings_list = [ + source_model.pretrained_model_name_or_path, + source_model.v2, + source_model.v_parameterization, + source_model.sdxl_checkbox, + train_dir, + image_folder, + output_dir, + logging_dir, + max_resolution, + min_bucket_reso, + max_bucket_reso, + batch_size, + advanced_training.flip_aug, + caption_metadata_filename, + latent_metadata_filename, + full_path, + basic_training.learning_rate, + basic_training.lr_scheduler, + basic_training.lr_warmup, + dataset_repeats, + basic_training.train_batch_size, + basic_training.epoch, + basic_training.save_every_n_epochs, + basic_training.mixed_precision, + basic_training.save_precision, + basic_training.seed, + basic_training.num_cpu_threads_per_process, + basic_training.learning_rate_te, + basic_training.learning_rate_te1, + basic_training.learning_rate_te2, + train_text_encoder, + advanced_training.full_bf16, + create_caption, + create_buckets, + source_model.save_model_as, + basic_training.caption_extension, + advanced_training.xformers, + advanced_training.clip_skip, + advanced_training.num_processes, + advanced_training.num_machines, + advanced_training.multi_gpu, + advanced_training.gpu_ids, + advanced_training.save_state, + advanced_training.resume, + advanced_training.gradient_checkpointing, + gradient_accumulation_steps, + block_lr, + advanced_training.mem_eff_attn, + advanced_training.shuffle_caption, + output_name, + advanced_training.max_token_length, + basic_training.max_train_epochs, + advanced_training.max_data_loader_n_workers, + advanced_training.full_fp16, + advanced_training.color_aug, + source_model.model_list, + basic_training.cache_latents, + basic_training.cache_latents_to_disk, + use_latent_files, + advanced_training.keep_tokens, + advanced_training.persistent_data_loader_workers, + advanced_training.bucket_no_upscale, + advanced_training.random_crop, + advanced_training.bucket_reso_steps, + advanced_training.v_pred_like_loss, + advanced_training.caption_dropout_every_n_epochs, + advanced_training.caption_dropout_rate, + basic_training.optimizer, + basic_training.optimizer_args, + basic_training.lr_scheduler_args, + advanced_training.noise_offset_type, + advanced_training.noise_offset, + advanced_training.adaptive_noise_scale, + advanced_training.multires_noise_iterations, + advanced_training.multires_noise_discount, + sample.sample_every_n_steps, + sample.sample_every_n_epochs, + sample.sample_sampler, + sample.sample_prompts, + advanced_training.additional_parameters, + advanced_training.vae_batch_size, + advanced_training.min_snr_gamma, + weighted_captions, + advanced_training.save_every_n_steps, + advanced_training.save_last_n_steps, + 
advanced_training.save_last_n_steps_state, + advanced_training.use_wandb, + advanced_training.wandb_api_key, + advanced_training.scale_v_pred_loss_like_noise_pred, + sdxl_params.sdxl_cache_text_encoder_outputs, + sdxl_params.sdxl_no_half_vae, + advanced_training.min_timestep, + advanced_training.max_timestep, + ] + + config.button_open_config.click( + open_configuration, + inputs=[dummy_db_true, dummy_db_false, config.config_file_name] + + settings_list + + [training_preset], + outputs=[config.config_file_name] + settings_list + [training_preset], + show_progress=False, + ) + + # config.button_open_config.click( + # open_configuration, + # inputs=[dummy_db_true, dummy_db_false, config.config_file_name] + settings_list, + # outputs=[config.config_file_name] + settings_list, + # show_progress=False, + # ) + + config.button_load_config.click( + open_configuration, + inputs=[dummy_db_false, dummy_db_false, config.config_file_name] + + settings_list + + [training_preset], + outputs=[config.config_file_name] + settings_list + [training_preset], + show_progress=False, + ) + + # config.button_load_config.click( + # open_configuration, + # inputs=[dummy_db_false, config.config_file_name] + settings_list, + # outputs=[config.config_file_name] + settings_list, + # show_progress=False, + # ) + + training_preset.input( + open_configuration, + inputs=[dummy_db_false, dummy_db_true, config.config_file_name] + + settings_list + + [training_preset], + outputs=[gr.Textbox(visible=False)] + settings_list + [training_preset], + show_progress=False, + ) + + button_run.click( + train_model, + inputs=[dummy_headless] + [dummy_db_false] + settings_list, + show_progress=False, + ) + + button_stop_training.click(executor.kill_command) + + button_print.click( + train_model, + inputs=[dummy_headless] + [dummy_db_true] + settings_list, + show_progress=False, + ) + + config.button_save_config.click( + save_configuration, + inputs=[dummy_db_false, config.config_file_name] + settings_list, + outputs=[config.config_file_name], + show_progress=False, + ) + + config.button_save_as_config.click( + save_configuration, + inputs=[dummy_db_true, config.config_file_name] + settings_list, + outputs=[config.config_file_name], + show_progress=False, + ) + + with gr.Tab("Guides"): + gr.Markdown("This section provides various finetuning guides and information...") + top_level_path = "./docs/Finetuning/top_level.md" + if os.path.exists(top_level_path): + with open(os.path.join(top_level_path), "r", encoding="utf8") as file: + guides_top_level = file.read() + "\n" + gr.Markdown(guides_top_level) diff --git a/library/git_caption_gui.py b/kohya_gui/git_caption_gui.py similarity index 98% rename from library/git_caption_gui.py rename to kohya_gui/git_caption_gui.py index 83e54cfbb..f59e31253 100644 --- a/library/git_caption_gui.py +++ b/kohya_gui/git_caption_gui.py @@ -4,7 +4,7 @@ import os from .common_gui import get_folder_path, add_pre_postfix -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/library/group_images_gui.py b/kohya_gui/group_images_gui.py similarity index 98% rename from library/group_images_gui.py rename to kohya_gui/group_images_gui.py index 73f64f0e2..e727f8427 100644 --- a/library/group_images_gui.py +++ b/kohya_gui/group_images_gui.py @@ -4,7 +4,7 @@ from .common_gui import get_folder_path import os -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = 
setup_logging() diff --git a/library/localization.py b/kohya_gui/localization.py similarity index 100% rename from library/localization.py rename to kohya_gui/localization.py diff --git a/library/localization_ext.py b/kohya_gui/localization_ext.py similarity index 94% rename from library/localization_ext.py rename to kohya_gui/localization_ext.py index 5c33fd0db..0f7c64653 100644 --- a/library/localization_ext.py +++ b/kohya_gui/localization_ext.py @@ -1,6 +1,6 @@ import os import gradio as gr -import library.localization as localization +import kohya_gui.localization as localization def file_path(fn): @@ -30,4 +30,4 @@ def template_response(*args, **kwargs): if not hasattr(localization, 'GrRoutesTemplateResponse'): - localization.GrRoutesTemplateResponse = gr.routes.templates.TemplateResponse \ No newline at end of file + localization.GrRoutesTemplateResponse = gr.routes.templates.TemplateResponse diff --git a/kohya_gui/lora_gui.py b/kohya_gui/lora_gui.py new file mode 100644 index 000000000..b67b4c737 --- /dev/null +++ b/kohya_gui/lora_gui.py @@ -0,0 +1,2019 @@ +import gradio as gr +import json +import math +import os +import lycoris +from datetime import datetime +from .common_gui import ( + get_file_path, + get_any_file_path, + get_saveasfile_path, + color_aug_changed, + run_cmd_advanced_training, + update_my_data, + check_if_model_exist, + output_message, + verify_image_folder_pattern, + SaveConfigFile, + save_to_file, + check_duplicate_filenames, +) +from .class_configuration_file import ConfigurationFile +from .class_source_model import SourceModel +from .class_basic_training import BasicTraining +from .class_advanced_training import AdvancedTraining +from .class_sdxl_parameters import SDXLParameters +from .class_folders import Folders +from .class_command_executor import CommandExecutor +from .tensorboard_gui import ( + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, +) +from .class_sample_images import SampleImages, run_cmd_sample +from .class_lora_tab import LoRATools + +from .dreambooth_folder_creation_gui import ( + gradio_dreambooth_folder_creation_tab, +) +from .dataset_balancing_gui import gradio_dataset_balancing_tab + +from .custom_logging import setup_logging + +# Set up logging +log = setup_logging() + +# Setup command executor +executor = CommandExecutor() + +button_run = gr.Button("Start training", variant="primary") + +button_stop_training = gr.Button("Stop training") + +document_symbol = "\U0001F4C4" # 📄 + + +def save_configuration( + save_as, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + sdxl, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + cache_latents_to_disk, + caption_extension, + enable_bucket, + gradient_checkpointing, + fp8_base, + full_fp16, + # no_token_padding, + stop_text_encoder_training, + min_bucket_reso, + max_bucket_reso, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + text_encoder_lr, + unet_lr, + network_dim, + lora_network_weights, + dim_from_weights, + color_aug, + flip_aug, + clip_skip, + num_processes, + num_machines, + multi_gpu, + gpu_ids, + gradient_accumulation_steps, + mem_eff_attn, + output_name, + model_list, + max_token_length, + max_train_epochs, + max_train_steps, + max_data_loader_n_workers, + 
network_alpha, + training_comment, + keep_tokens, + lr_scheduler_num_cycles, + lr_scheduler_power, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + v_pred_like_loss, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + lr_scheduler_args, + max_grad_norm, + noise_offset_type, + noise_offset, + adaptive_noise_scale, + multires_noise_iterations, + multires_noise_discount, + LoRA_type, + factor, + use_cp, + use_tucker, + use_scalar, + rank_dropout_scale, + constrain, + rescaled, + train_norm, + decompose_both, + train_on_input, + conv_dim, + conv_alpha, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, + down_lr_weight, + mid_lr_weight, + up_lr_weight, + block_lr_zero_threshold, + block_dims, + block_alphas, + conv_block_dims, + conv_block_alphas, + weighted_captions, + unit, + save_every_n_steps, + save_last_n_steps, + save_last_n_steps_state, + use_wandb, + wandb_api_key, + scale_v_pred_loss_like_noise_pred, + scale_weight_norms, + network_dropout, + rank_dropout, + module_dropout, + sdxl_cache_text_encoder_outputs, + sdxl_no_half_vae, + full_bf16, + min_timestep, + max_timestep, + vae, + LyCORIS_preset, + debiased_estimation_loss, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + original_file_path = file_path + + save_as_bool = True if save_as.get("label") == "True" else False + + if save_as_bool: + log.info("Save as...") + file_path = get_saveasfile_path(file_path) + else: + log.info("Save...") + if file_path == None or file_path == "": + file_path = get_saveasfile_path(file_path) + + # log.info(file_path) + + if file_path == None or file_path == "": + return original_file_path # In case a file_path was provided and the user decides to cancel the open action + + # Extract the destination directory from the file path + destination_directory = os.path.dirname(file_path) + + # Create the destination directory if it doesn't exist + if not os.path.exists(destination_directory): + os.makedirs(destination_directory) + + SaveConfigFile( + parameters=parameters, + file_path=file_path, + exclusion=["file_path", "save_as"], + ) + + return file_path + + +def open_configuration( + ask_for_file, + apply_preset, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + sdxl, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + cache_latents_to_disk, + caption_extension, + enable_bucket, + gradient_checkpointing, + fp8_base, + full_fp16, + # no_token_padding, + stop_text_encoder_training, + min_bucket_reso, + max_bucket_reso, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + text_encoder_lr, + unet_lr, + network_dim, + lora_network_weights, + dim_from_weights, + color_aug, + flip_aug, + clip_skip, + num_processes, + num_machines, + multi_gpu, + gpu_ids, + gradient_accumulation_steps, + mem_eff_attn, + output_name, + model_list, + max_token_length, + max_train_epochs, + max_train_steps, + max_data_loader_n_workers, + network_alpha, + training_comment, + keep_tokens, + lr_scheduler_num_cycles, + lr_scheduler_power, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, +
bucket_reso_steps, + v_pred_like_loss, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + lr_scheduler_args, + max_grad_norm, + noise_offset_type, + noise_offset, + adaptive_noise_scale, + multires_noise_iterations, + multires_noise_discount, + LoRA_type, + factor, + use_cp, + use_tucker, + use_scalar, + rank_dropout_scale, + constrain, + rescaled, + train_norm, + decompose_both, + train_on_input, + conv_dim, + conv_alpha, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, + down_lr_weight, + mid_lr_weight, + up_lr_weight, + block_lr_zero_threshold, + block_dims, + block_alphas, + conv_block_dims, + conv_block_alphas, + weighted_captions, + unit, + save_every_n_steps, + save_last_n_steps, + save_last_n_steps_state, + use_wandb, + wandb_api_key, + scale_v_pred_loss_like_noise_pred, + scale_weight_norms, + network_dropout, + rank_dropout, + module_dropout, + sdxl_cache_text_encoder_outputs, + sdxl_no_half_vae, + full_bf16, + min_timestep, + max_timestep, + vae, + LyCORIS_preset, + debiased_estimation_loss, + training_preset, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get("label") == "True" else False + apply_preset = True if apply_preset.get("label") == "True" else False + + # Check if we are "applying" a preset or a config + if apply_preset: + if training_preset != "none": + log.info(f"Applying preset {training_preset}...") + file_path = f"./presets/lora/{training_preset}.json" + else: + # No preset selected: reset the `training_preset` field to "none" + # Find the index of the `training_preset` parameter using the `index()` method + training_preset_index = parameters.index(("training_preset", training_preset)) + + # Update the value of `training_preset` by directly assigning "none" to it + parameters[training_preset_index] = ("training_preset", "none") + + original_file_path = file_path + + if ask_for_file: + file_path = get_file_path(file_path) + + if not file_path == "" and not file_path == None: + # Load variables from JSON file + with open(file_path, "r") as f: + my_data = json.load(f) + log.info("Loading config...") + + # Update values to fix deprecated options, set appropriate optimizer if it is set to True, etc.
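+ # e.g. the deprecated use_8bit_adam checkbox is remapped onto the matching optimizer setting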
+ my_data = update_my_data(my_data) + else: + file_path = original_file_path # In case a file_path was provided and the user decides to cancel the open action + my_data = {} + + values = [file_path] + for key, value in parameters: + # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found + if key not in ["ask_for_file", "apply_preset", "file_path"]: + json_value = my_data.get(key) + # if isinstance(json_value, str) and json_value == '': + # # If the JSON value is an empty string, use the default value + # values.append(value) + # else: + # Otherwise, use the JSON value if not None, otherwise use the default value + values.append(json_value if json_value is not None else value) + + # Show the convolution (LoCon) parameter row when the loaded LoRA type uses conv settings + if my_data.get("LoRA_type", "Standard") in { + "LoCon", + "Kohya DyLoRA", + "Kohya LoCon", + "LoRA-FA", + "LyCORIS/Diag-OFT", + "LyCORIS/DyLoRA", + "LyCORIS/LoHa", + "LyCORIS/LoKr", + "LyCORIS/LoCon", + "LyCORIS/GLoRA", + }: + values.append(gr.Row(visible=True)) + else: + values.append(gr.Row(visible=False)) + + return tuple(values) + + +def train_model( + headless, + print_only, + pretrained_model_name_or_path, + v2, + v_parameterization, + sdxl, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + cache_latents_to_disk, + caption_extension, + enable_bucket, + gradient_checkpointing, + fp8_base, + full_fp16, + # no_token_padding, + stop_text_encoder_training_pct, + min_bucket_reso, + max_bucket_reso, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + text_encoder_lr, + unet_lr, + network_dim, + lora_network_weights, + dim_from_weights, + color_aug, + flip_aug, + clip_skip, + num_processes, + num_machines, + multi_gpu, + gpu_ids, + gradient_accumulation_steps, + mem_eff_attn, + output_name, + model_list, # Keep this.
Yes, it is unused here but required given the common list used + max_token_length, + max_train_epochs, + max_train_steps, + max_data_loader_n_workers, + network_alpha, + training_comment, + keep_tokens, + lr_scheduler_num_cycles, + lr_scheduler_power, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + v_pred_like_loss, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + lr_scheduler_args, + max_grad_norm, + noise_offset_type, + noise_offset, + adaptive_noise_scale, + multires_noise_iterations, + multires_noise_discount, + LoRA_type, + factor, + use_cp, + use_tucker, + use_scalar, + rank_dropout_scale, + constrain, + rescaled, + train_norm, + decompose_both, + train_on_input, + conv_dim, + conv_alpha, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, + down_lr_weight, + mid_lr_weight, + up_lr_weight, + block_lr_zero_threshold, + block_dims, + block_alphas, + conv_block_dims, + conv_block_alphas, + weighted_captions, + unit, + save_every_n_steps, + save_last_n_steps, + save_last_n_steps_state, + use_wandb, + wandb_api_key, + scale_v_pred_loss_like_noise_pred, + scale_weight_norms, + network_dropout, + rank_dropout, + module_dropout, + sdxl_cache_text_encoder_outputs, + sdxl_no_half_vae, + full_bf16, + min_timestep, + max_timestep, + vae, + LyCORIS_preset, + debiased_estimation_loss, +): + # Get list of function parameters and values + parameters = list(locals().items()) + global command_running + + print_only_bool = True if print_only.get("label") == "True" else False + log.info(f"Start training LoRA {LoRA_type} ...") + headless_bool = True if headless.get("label") == "True" else False + + if pretrained_model_name_or_path == "": + output_message( + msg="Source model information is missing", headless=headless_bool + ) + return + + if train_data_dir == "": + output_message(msg="Image folder path is missing", headless=headless_bool) + return + + # Check if there are files with the same filename but a different image extension... warn the user if that is the case. + check_duplicate_filenames(train_data_dir) + + if not os.path.exists(train_data_dir): + output_message(msg="Image folder does not exist", headless=headless_bool) + return + + if not verify_image_folder_pattern(train_data_dir): + return + + if reg_data_dir != "": + if not os.path.exists(reg_data_dir): + output_message( + msg="Regularisation folder does not exist", + headless=headless_bool, + ) + return + + if not verify_image_folder_pattern(reg_data_dir): + return + + if output_dir == "": + output_message(msg="Output folder path is missing", headless=headless_bool) + return + + if int(bucket_reso_steps) < 1: + output_message( + msg="Bucket resolution steps must be greater than 0", + headless=headless_bool, + ) + return + + if noise_offset == "": + noise_offset = 0 + + if float(noise_offset) > 1 or float(noise_offset) < 0: + output_message( + msg="Noise offset must be a value between 0 and 1", + headless=headless_bool, + ) + return + + # if float(noise_offset) > 0 and ( + # multires_noise_iterations > 0 or multires_noise_discount > 0 + # ): + # output_message( + # msg="noise offset and multires_noise can't be set at the same time.
Only use one or the other.", + # title='Error', + # headless=headless_bool, + # ) + # return + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + if stop_text_encoder_training_pct > 0: + output_message( + msg='The "stop text encoder training" option is not yet supported. Ignoring', + headless=headless_bool, + ) + stop_text_encoder_training_pct = 0 + + if check_if_model_exist( + output_name, output_dir, save_model_as, headless=headless_bool + ): + return + + # if optimizer == 'Adafactor' and lr_warmup != '0': + # output_message( + # msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", + # title='Warning', + # headless=headless_bool, + # ) + # lr_warmup = '0' + + # If the string is empty, set it to 0. + if text_encoder_lr == "": + text_encoder_lr = 0 + if unet_lr == "": + unet_lr = 0 + + # Get a list of all subfolders in train_data_dir + subfolders = [ + f + for f in os.listdir(train_data_dir) + if os.path.isdir(os.path.join(train_data_dir, f)) + ] + + total_steps = 0 + + # Loop through each subfolder and extract the number of repeats + for folder in subfolders: + try: + # Extract the number of repeats from the folder name (expected format: <repeats>_<name>) + repeats = int(folder.split("_")[0]) + + # Count the number of images in the folder + num_images = len( + [ + f + for f, lower_f in ( + (file, file.lower()) + for file in os.listdir(os.path.join(train_data_dir, folder)) + ) + if lower_f.endswith((".jpg", ".jpeg", ".png", ".webp")) + ] + ) + + log.info(f"Folder {folder}: {num_images} images found") + + # Calculate the total number of steps for this folder + steps = repeats * num_images + + # Log the result + log.info(f"Folder {folder}: {steps} steps") + + total_steps += steps + + except ValueError: + # Handle the case where the folder name does not contain an underscore + log.info(f"Error: '{folder}' does not contain an underscore, skipping...") + + if reg_data_dir == "": + reg_factor = 1 + else: + log.warning( + "Regularisation images are used... Will double the number of steps required..."
+ ) + reg_factor = 2 + + log.info(f"Total steps: {total_steps}") + log.info(f"Train batch size: {train_batch_size}") + log.info(f"Gradient accumulation steps: {gradient_accumulation_steps}") + log.info(f"Epoch: {epoch}") + log.info(f"Regularization factor: {reg_factor}") + + if max_train_steps == "" or max_train_steps == "0": + # calculate max_train_steps + max_train_steps = int( + math.ceil( + float(total_steps) + / int(train_batch_size) + / int(gradient_accumulation_steps) + * int(epoch) + * int(reg_factor) + ) + ) + log.info( + f"max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}" + ) + + # calculate stop encoder training + if stop_text_encoder_training_pct == None: + stop_text_encoder_training = 0 + else: + stop_text_encoder_training = math.ceil( + float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) + ) + log.info(f"stop_text_encoder_training = {stop_text_encoder_training}") + + lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) + log.info(f"lr_warmup_steps = {lr_warmup_steps}") + + run_cmd = "accelerate launch" + + run_cmd += run_cmd_advanced_training( + num_processes=num_processes, + num_machines=num_machines, + multi_gpu=multi_gpu, + gpu_ids=gpu_ids, + num_cpu_threads_per_process=num_cpu_threads_per_process, + ) + + if sdxl: + run_cmd += f' "./sdxl_train_network.py"' + else: + run_cmd += f' "./train_network.py"' + + if LoRA_type == "LyCORIS/Diag-OFT": + network_module = "lycoris.kohya" + network_args = f' "preset={LyCORIS_preset}" "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "module_dropout={module_dropout}" "use_tucker={use_tucker}" "use_scalar={use_scalar}" "rank_dropout_scale={rank_dropout_scale}" "constrain={constrain}" "rescaled={rescaled}" "algo=diag-oft" ' + + if LoRA_type == "LyCORIS/DyLoRA": + network_module = "lycoris.kohya" + network_args = f' "preset={LyCORIS_preset}" "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "use_tucker={use_tucker}" "block_size={unit}" "rank_dropout={rank_dropout}" "module_dropout={module_dropout}" "algo=dylora" "train_norm={train_norm}"' + + if LoRA_type == "LyCORIS/GLoRA": + network_module = "lycoris.kohya" + network_args = f' "preset={LyCORIS_preset}" "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "rank_dropout={rank_dropout}" "module_dropout={module_dropout}" "rank_dropout_scale={rank_dropout_scale}" "algo=glora" "train_norm={train_norm}"' + + if LoRA_type == "LyCORIS/iA3": + network_module = "lycoris.kohya" + network_args = f' "preset={LyCORIS_preset}" "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "train_on_input={train_on_input}" "algo=ia3"' + + if LoRA_type == "LoCon" or LoRA_type == "LyCORIS/LoCon": + network_module = "lycoris.kohya" + network_args = f' "preset={LyCORIS_preset}" "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "rank_dropout={rank_dropout}" "module_dropout={module_dropout}" "use_tucker={use_tucker}" "use_scalar={use_scalar}" "rank_dropout_scale={rank_dropout_scale}" "algo=locon" "train_norm={train_norm}"' + + if LoRA_type == "LyCORIS/LoHa": + network_module = "lycoris.kohya" + network_args = f' "preset={LyCORIS_preset}" "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "rank_dropout={rank_dropout}" "module_dropout={module_dropout}" "use_tucker={use_tucker}" "use_scalar={use_scalar}" "rank_dropout_scale={rank_dropout_scale}" "algo=loha" "train_norm={train_norm}"' + + if LoRA_type == "LyCORIS/LoKr": + network_module = "lycoris.kohya" + network_args = f' "preset={LyCORIS_preset}" "conv_dim={conv_dim}"
"conv_alpha={conv_alpha}" "rank_dropout={rank_dropout}" "module_dropout={module_dropout}" "factor={factor}" "use_cp={use_cp}" "use_scalar={use_scalar}" "decompose_both={decompose_both}" "rank_dropout_scale={rank_dropout_scale}" "algo=lokr" "train_norm={train_norm}"' + + if LoRA_type == "LyCORIS/Native Fine-Tuning": + network_module = "lycoris.kohya" + network_args = f' "preset={LyCORIS_preset}" "rank_dropout={rank_dropout}" "module_dropout={module_dropout}" "use_tucker={use_tucker}" "use_scalar={use_scalar}" "rank_dropout_scale={rank_dropout_scale}" "algo=full" "train_norm={train_norm}"' + + if LoRA_type in ["Kohya LoCon", "Standard"]: + kohya_lora_var_list = [ + "down_lr_weight", + "mid_lr_weight", + "up_lr_weight", + "block_lr_zero_threshold", + "block_dims", + "block_alphas", + "conv_block_dims", + "conv_block_alphas", + "rank_dropout", + "module_dropout", + ] + + network_module = "networks.lora" + kohya_lora_vars = { + key: value + for key, value in vars().items() + if key in kohya_lora_var_list and value + } + + network_args = "" + if LoRA_type == "Kohya LoCon": + network_args += f' conv_dim="{conv_dim}" conv_alpha="{conv_alpha}"' + + for key, value in kohya_lora_vars.items(): + if value: + network_args += f' {key}="{value}"' + + if LoRA_type in [ + "LoRA-FA", + ]: + kohya_lora_var_list = [ + "down_lr_weight", + "mid_lr_weight", + "up_lr_weight", + "block_lr_zero_threshold", + "block_dims", + "block_alphas", + "conv_block_dims", + "conv_block_alphas", + "rank_dropout", + "module_dropout", + ] + + network_module = "networks.lora_fa" + kohya_lora_vars = { + key: value + for key, value in vars().items() + if key in kohya_lora_var_list and value + } + + network_args = "" + if LoRA_type == "Kohya LoCon": + network_args += f' conv_dim="{conv_dim}" conv_alpha="{conv_alpha}"' + + for key, value in kohya_lora_vars.items(): + if value: + network_args += f' {key}="{value}"' + + if LoRA_type in ["Kohya DyLoRA"]: + kohya_lora_var_list = [ + "conv_dim", + "conv_alpha", + "down_lr_weight", + "mid_lr_weight", + "up_lr_weight", + "block_lr_zero_threshold", + "block_dims", + "block_alphas", + "conv_block_dims", + "conv_block_alphas", + "rank_dropout", + "module_dropout", + "unit", + ] + + network_module = "networks.dylora" + kohya_lora_vars = { + key: value + for key, value in vars().items() + if key in kohya_lora_var_list and value + } + + network_args = "" + + for key, value in kohya_lora_vars.items(): + if value: + network_args += f' {key}="{value}"' + + network_train_text_encoder_only = False + network_train_unet_only = False + + # Convert learning rates to float once and store the result for re-use + if text_encoder_lr is None: + output_message( + msg="Please input valid Text Encoder learning rate (between 0 and 1)", headless=headless_bool + ) + return + if unet_lr is None: + output_message( + msg="Please input valid Unet learning rate (between 0 and 1)", headless=headless_bool + ) + return + text_encoder_lr_float = float(text_encoder_lr) + unet_lr_float = float(unet_lr) + + + + # Determine the training configuration based on learning rate values + if text_encoder_lr_float == 0 and unet_lr_float == 0: + if float(learning_rate) == 0: + output_message( + msg="Please input learning rate values.", headless=headless_bool + ) + return + elif text_encoder_lr_float != 0 and unet_lr_float == 0: + network_train_text_encoder_only = True + elif text_encoder_lr_float == 0 and unet_lr_float != 0: + network_train_unet_only = True + # If both learning rates are non-zero, no specific flags need to be set + + 
run_cmd += run_cmd_advanced_training( + adaptive_noise_scale=adaptive_noise_scale, + additional_parameters=additional_parameters, + bucket_no_upscale=bucket_no_upscale, + bucket_reso_steps=bucket_reso_steps, + cache_latents=cache_latents, + cache_latents_to_disk=cache_latents_to_disk, + cache_text_encoder_outputs=True if sdxl and sdxl_cache_text_encoder_outputs else None, + caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, + caption_dropout_rate=caption_dropout_rate, + caption_extension=caption_extension, + clip_skip=clip_skip, + color_aug=color_aug, + debiased_estimation_loss=debiased_estimation_loss, + dim_from_weights=dim_from_weights, + enable_bucket=enable_bucket, + epoch=epoch, + flip_aug=flip_aug, + fp8_base=fp8_base, + full_bf16=full_bf16, + full_fp16=full_fp16, + gradient_accumulation_steps=gradient_accumulation_steps, + gradient_checkpointing=gradient_checkpointing, + keep_tokens=keep_tokens, + learning_rate=learning_rate, + logging_dir=logging_dir, + lora_network_weights=lora_network_weights, + lr_scheduler=lr_scheduler, + lr_scheduler_args=lr_scheduler_args, + lr_scheduler_num_cycles=lr_scheduler_num_cycles, + lr_scheduler_power=lr_scheduler_power, + lr_warmup_steps=lr_warmup_steps, + max_bucket_reso=max_bucket_reso, + max_data_loader_n_workers=max_data_loader_n_workers, + max_grad_norm=max_grad_norm, + max_resolution=max_resolution, + max_timestep=max_timestep, + max_token_length=max_token_length, + max_train_epochs=max_train_epochs, + max_train_steps=max_train_steps, + mem_eff_attn=mem_eff_attn, + min_bucket_reso=min_bucket_reso, + min_snr_gamma=min_snr_gamma, + min_timestep=min_timestep, + mixed_precision=mixed_precision, + multires_noise_discount=multires_noise_discount, + multires_noise_iterations=multires_noise_iterations, + network_alpha=network_alpha, + network_args=network_args, + network_dim=network_dim, + network_dropout=network_dropout, + network_module=network_module, + network_train_unet_only=network_train_unet_only, + network_train_text_encoder_only=network_train_text_encoder_only, + no_half_vae=True if sdxl and sdxl_no_half_vae else None, + # no_token_padding=no_token_padding, + noise_offset=noise_offset, + noise_offset_type=noise_offset_type, + optimizer=optimizer, + optimizer_args=optimizer_args, + output_dir=output_dir, + output_name=output_name, + persistent_data_loader_workers=persistent_data_loader_workers, + pretrained_model_name_or_path=pretrained_model_name_or_path, + prior_loss_weight=prior_loss_weight, + random_crop=random_crop, + reg_data_dir=reg_data_dir, + resume=resume, + save_every_n_epochs=save_every_n_epochs, + save_every_n_steps=save_every_n_steps, + save_last_n_steps=save_last_n_steps, + save_last_n_steps_state=save_last_n_steps_state, + save_model_as=save_model_as, + save_precision=save_precision, + save_state=save_state, + scale_v_pred_loss_like_noise_pred=scale_v_pred_loss_like_noise_pred, + scale_weight_norms=scale_weight_norms, + seed=seed, + shuffle_caption=shuffle_caption, + stop_text_encoder_training=stop_text_encoder_training, + text_encoder_lr=text_encoder_lr, + train_batch_size=train_batch_size, + train_data_dir=train_data_dir, + training_comment=training_comment, + unet_lr=unet_lr, + use_wandb=use_wandb, + v2=v2, + v_parameterization=v_parameterization, + v_pred_like_loss=v_pred_like_loss, + vae=vae, + vae_batch_size=vae_batch_size, + wandb_api_key=wandb_api_key, + weighted_captions=weighted_captions, + xformers=xformers, + ) + + run_cmd += run_cmd_sample( + sample_every_n_steps, + sample_every_n_epochs, + 
sample_sampler, + sample_prompts, + output_dir, + ) + + if print_only_bool: + log.warning( + "Here is the trainer command as a reference. It will not be executed:\n" + ) + print(run_cmd) + + save_to_file(run_cmd) + else: + # Saving config file for model + current_datetime = datetime.now() + formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") + file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json") + + log.info(f"Saving training config to {file_path}...") + + SaveConfigFile( + parameters=parameters, + file_path=file_path, + exclusion=["file_path", "save_as", "headless", "print_only"], + ) + + log.info(run_cmd) + # Run the command + executor.execute_command(run_cmd=run_cmd) + + # # check if output_dir/last is a folder... therefore it is a diffuser model + # last_dir = pathlib.Path(f'{output_dir}/{output_name}') + + # if not last_dir.is_dir(): + # # Copy inference model for v2 if required + # save_inference_file( + # output_dir, v2, v_parameterization, output_name + # ) + + +def lora_tab( + train_data_dir_input=gr.Textbox(), + reg_data_dir_input=gr.Textbox(), + output_dir_input=gr.Textbox(), + logging_dir_input=gr.Textbox(), + headless=False, +): + dummy_db_true = gr.Label(value=True, visible=False) + dummy_db_false = gr.Label(value=False, visible=False) + dummy_headless = gr.Label(value=headless, visible=False) + + with gr.Tab("Training"): + gr.Markdown( + "Train a custom model using the kohya train network LoRA Python code..." + ) + + # Setup Configuration Files Gradio + config = ConfigurationFile(headless) + + source_model = SourceModel( + save_model_as_choices=[ + "ckpt", + "safetensors", + ], + headless=headless, + ) + + with gr.Tab("Folders"): + folders = Folders(headless=headless) + + with gr.Tab("Parameters"): + + def list_presets(path): + json_files = [] + + # Insert "none" at the beginning of the list + json_files.insert(0, "none") + + for file in os.listdir(path): + if file.endswith(".json"): + json_files.append(os.path.splitext(file)[0]) + + user_presets_path = os.path.join(path, "user_presets") + if os.path.isdir(user_presets_path): + for file in os.listdir(user_presets_path): + if file.endswith(".json"): + preset_name = os.path.splitext(file)[0] + json_files.append(os.path.join("user_presets", preset_name)) + + return json_files + + training_preset = gr.Dropdown( + label="Presets", + choices=list_presets("./presets/lora"), + elem_id="myDropdown", + value="none" + ) + + with gr.Tab("Basic", elem_id="basic_tab"): + with gr.Row(): + LoRA_type = gr.Dropdown( + label="LoRA type", + choices=[ + "Kohya DyLoRA", + "Kohya LoCon", + "LoRA-FA", + "LyCORIS/DyLoRA", + "LyCORIS/iA3", + "LyCORIS/Diag-OFT", + "LyCORIS/GLoRA", + "LyCORIS/LoCon", + "LyCORIS/LoHa", + "LyCORIS/LoKr", + "LyCORIS/Native Fine-Tuning", + "Standard", + ], + value="Standard", + ) + LyCORIS_preset = gr.Dropdown( + label="LyCORIS Preset", + choices=[ + "attn-mlp", + "attn-only", + "full", + "full-lin", + "unet-transformer-only", + "unet-convblock-only", + ], + value="full", + visible=False, + interactive=True + # info="https://github.com/KohakuBlueleaf/LyCORIS/blob/0006e2ffa05a48d8818112d9f70da74c0cd30b99/docs/Preset.md" + ) + with gr.Group(): + with gr.Row(): + lora_network_weights = gr.Textbox( + label="LoRA network weights", + placeholder="(Optional)", + info="Path to existing LoRA network weights to resume training from", + ) + lora_network_weights_file = gr.Button( + document_symbol, + elem_id="open_folder_small", + visible=(not headless), + ) + lora_network_weights_file.click( +
get_any_file_path, + inputs=[lora_network_weights], + outputs=lora_network_weights, + show_progress=False, + ) + dim_from_weights = gr.Checkbox( + label="DIM from weights", + value=False, + info="Automatically determine the dim(rank) from the weight file.", + ) + basic_training = BasicTraining( + learning_rate_value="0.0001", + lr_scheduler_value="cosine", + lr_warmup_value="10", + sdxl_checkbox=source_model.sdxl_checkbox, + ) + + with gr.Row(): + text_encoder_lr = gr.Number( + label="Text Encoder learning rate", + value="0.0001", + info="Optional", + minimum=0, + maximum=1, + ) + + unet_lr = gr.Number( + label="Unet learning rate", + value="0.0001", + info="Optional", + minimum=0, + maximum=1, + ) + + # Add SDXL Parameters + sdxl_params = SDXLParameters(source_model.sdxl_checkbox) + + with gr.Row(): + factor = gr.Slider( + label="LoKr factor", + value=-1, + minimum=-1, + maximum=64, + step=1, + visible=False, + ) + use_cp = gr.Checkbox( + value=False, + label="Use CP decomposition", + info="A two-step approach utilizing tensor decomposition and fine-tuning to accelerate convolution layers in large neural networks, resulting in significant CPU speedups with minor accuracy drops.", + visible=False, + ) + use_tucker = gr.Checkbox( + value=False, + label="Use Tucker decomposition", + info="Efficiently decompose tensor shapes, resulting in a sequence of convolution layers with varying dimensions and Hadamard product implementation through multiplication of two distinct tensors.", + visible=False, + ) + use_scalar = gr.Checkbox( + value=False, + label="Use Scalar", + info="Train an additional scalar in front of the weight difference, use a different weight initialization strategy.", + visible=False, + ) + rank_dropout_scale = gr.Checkbox( + value=False, + label="Rank Dropout Scale", + info="Adjusts the scale of the rank dropout to maintain the average dropout rate, ensuring more consistent regularization across different layers.", + visible=False, + ) + constrain = gr.Number( + value="0.0", + label="Constrain OFT", + info="Limits the norm of the oft_blocks, ensuring that their magnitude does not exceed a specified threshold, thus controlling the extent of the transformation applied.", + visible=False, + ) + rescaled = gr.Checkbox( + value=False, + label="Rescaled OFT", + info="applies an additional scaling factor to the oft_blocks, allowing for further adjustment of their impact on the model's transformations.", + visible=False, + ) + train_norm = gr.Checkbox( + value=False, + label="Train Norm", + info="Selects trainable layers in a network, but trains normalization layers identically across methods as they lack matrix decomposition.", + visible=False, + ) + decompose_both = gr.Checkbox( + value=False, + label="LoKr decompose both", + info=" Controls whether both input and output dimensions of the layer's weights are decomposed into smaller matrices for reparameterization.", + visible=False, + ) + train_on_input = gr.Checkbox( + value=True, + label="iA3 train on input", + info="Set if we change the information going into the system (True) or the information coming out of it (False).", + visible=False, + ) + + with gr.Row() as network_row: + network_dim = gr.Slider( + minimum=1, + maximum=512, + label="Network Rank (Dimension)", + value=8, + step=1, + interactive=True, + ) + network_alpha = gr.Slider( + minimum=0.1, + maximum=1024, + label="Network Alpha", + value=1, + step=0.1, + interactive=True, + info="alpha for LoRA weight scaling", + ) + with gr.Row(visible=False) as convolution_row: + # 
locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (does not support v2 base models) (may not be able to use some utilities now)', value=False) + conv_dim = gr.Slider( + minimum=0, + maximum=512, + value=1, + step=1, + label="Convolution Rank (Dimension)", + ) + conv_alpha = gr.Slider( + minimum=0, + maximum=512, + value=1, + step=1, + label="Convolution Alpha", + ) + with gr.Row(): + scale_weight_norms = gr.Slider( + label="Scale weight norms", + value=0, + minimum=0, + maximum=10, + step=0.01, + info="Max Norm Regularization is a technique to stabilize network training by limiting the norm of network weights. It may be effective in suppressing overfitting of LoRA and improving stability when used with other LoRAs. See PR #545 on kohya_ss/sd_scripts repo for details. Recommended setting: 1. Higher is weaker, lower is stronger.", + interactive=True, + ) + network_dropout = gr.Slider( + label="Network dropout", + value=0, + minimum=0, + maximum=1, + step=0.01, + info="This is a normal probability dropout at the neuron level. In the case of LoRA, it is applied to the output of down. Recommended range 0.1 to 0.5", + ) + rank_dropout = gr.Slider( + label="Rank dropout", + value=0, + minimum=0, + maximum=1, + step=0.01, + info="Specify `rank_dropout` to drop out each rank with the specified probability. Recommended range 0.1 to 0.3", + ) + module_dropout = gr.Slider( + label="Module dropout", + value=0.0, + minimum=0.0, + maximum=1.0, + step=0.01, + info="Specify `module_dropout` to drop out each module with the specified probability. Recommended range 0.1 to 0.3", + ) + with gr.Row(visible=False) as kohya_dylora: + unit = gr.Slider( + minimum=1, + maximum=64, + label="DyLoRA Unit / Block size", + value=1, + step=1, + interactive=True, + ) + + # Show or hide LoCon conv settings depending on LoRA type selection + def update_LoRA_settings( + LoRA_type, + conv_dim, + network_dim, + ): + log.info("LoRA type changed...") + + lora_settings_config = { + "network_row": { + "gr_type": gr.Row, + "update_params": { + "visible": LoRA_type + in { + "Kohya DyLoRA", + "Kohya LoCon", + "LoRA-FA", + "LyCORIS/Diag-OFT", + "LyCORIS/DyLoRA", + "LyCORIS/GLoRA", + "LyCORIS/LoCon", + "LyCORIS/LoHa", + "LyCORIS/LoKr", + "Standard", + }, + }, + }, + "convolution_row": { + "gr_type": gr.Row, + "update_params": { + "visible": LoRA_type + in { + "LoCon", + "Kohya DyLoRA", + "Kohya LoCon", + "LoRA-FA", + "LyCORIS/Diag-OFT", + "LyCORIS/DyLoRA", + "LyCORIS/LoHa", + "LyCORIS/LoKr", + "LyCORIS/LoCon", + "LyCORIS/GLoRA", + }, + }, + }, + "kohya_advanced_lora": { + "gr_type": gr.Row, + "update_params": { + "visible": LoRA_type + in { + "Standard", + "Kohya DyLoRA", + "Kohya LoCon", + "LoRA-FA", + }, + }, + }, + "kohya_dylora": { + "gr_type": gr.Row, + "update_params": { + "visible": LoRA_type + in { + "Kohya DyLoRA", + "LyCORIS/DyLoRA", + }, + }, + }, + "lora_network_weights": { + "gr_type": gr.Textbox, + "update_params": { + "visible": LoRA_type + in { + "Standard", + "LoCon", + "Kohya DyLoRA", + "Kohya LoCon", + "LoRA-FA", + "LyCORIS/Diag-OFT", + "LyCORIS/DyLoRA", + "LyCORIS/GLoRA", + "LyCORIS/LoHa", + "LyCORIS/LoCon", + "LyCORIS/LoKr", + }, + }, + }, + "lora_network_weights_file": { + "gr_type": gr.Button, + "update_params": { + "visible": LoRA_type + in { + "Standard", + "LoCon", + "Kohya DyLoRA", + "Kohya LoCon", + "LoRA-FA", + "LyCORIS/Diag-OFT", + "LyCORIS/DyLoRA", + "LyCORIS/GLoRA", + "LyCORIS/LoHa", + "LyCORIS/LoCon", + "LyCORIS/LoKr", + }, + }, + }, + "dim_from_weights": { + "gr_type": gr.Checkbox, + "update_params":
{ + "visible": LoRA_type + in { + "Standard", + "LoCon", + "Kohya DyLoRA", + "Kohya LoCon", + "LoRA-FA", + "LyCORIS/Diag-OFT", + "LyCORIS/DyLoRA", + "LyCORIS/GLoRA", + "LyCORIS/LoHa", + "LyCORIS/LoCon", + "LyCORIS/LoKr", + } + }, + }, + "factor": { + "gr_type": gr.Slider, + "update_params": { + "visible": LoRA_type + in { + "LyCORIS/LoKr", + }, + }, + }, + "conv_dim": { + "gr_type": gr.Slider, + "update_params": { + "maximum": 100000 + if LoRA_type + in { + "LyCORIS/LoHa", + "LyCORIS/LoKr", + "LyCORIS/Diag-OFT", + } + else 512, + "value": conv_dim, # if conv_dim > 512 else conv_dim, + }, + }, + "network_dim": { + "gr_type": gr.Slider, + "update_params": { + "maximum": 100000 + if LoRA_type + in { + "LyCORIS/LoHa", + "LyCORIS/LoKr", + "LyCORIS/Diag-OFT", + } + else 512, + "value": network_dim, # if network_dim > 512 else network_dim, + }, + }, + "use_cp": { + "gr_type": gr.Checkbox, + "update_params": { + "visible": LoRA_type + in { + "LyCORIS/LoKr", + }, + }, + }, + "use_tucker": { + "gr_type": gr.Checkbox, + "update_params": { + "visible": LoRA_type + in { + "LyCORIS/Diag-OFT", + "LyCORIS/DyLoRA", + "LyCORIS/LoCon", + "LyCORIS/LoHa", + "LyCORIS/Native Fine-Tuning", + }, + }, + }, + "use_scalar": { + "gr_type": gr.Checkbox, + "update_params": { + "visible": LoRA_type + in { + "LyCORIS/Diag-OFT", + "LyCORIS/LoCon", + "LyCORIS/LoHa", + "LyCORIS/LoKr", + "LyCORIS/Native Fine-Tuning", + }, + }, + }, + "rank_dropout_scale": { + "gr_type": gr.Checkbox, + "update_params": { + "visible": LoRA_type + in { + "LyCORIS/Diag-OFT", + "LyCORIS/GLoRA", + "LyCORIS/LoCon", + "LyCORIS/LoHa", + "LyCORIS/LoKr", + "LyCORIS/Native Fine-Tuning", + }, + }, + }, + "constrain": { + "gr_type": gr.Number, + "update_params": { + "visible": LoRA_type + in { + "LyCORIS/Diag-OFT", + }, + }, + }, + "rescaled": { + "gr_type": gr.Checkbox, + "update_params": { + "visible": LoRA_type + in { + "LyCORIS/Diag-OFT", + }, + }, + }, + "train_norm": { + "gr_type": gr.Checkbox, + "update_params": { + "visible": LoRA_type + in { + "LyCORIS/DyLoRA", + "LyCORIS/Diag-OFT", + "LyCORIS/GLoRA", + "LyCORIS/LoCon", + "LyCORIS/LoHa", + "LyCORIS/LoKr", + "LyCORIS/Native Fine-Tuning", + }, + }, + }, + "decompose_both": { + "gr_type": gr.Checkbox, + "update_params": { + "visible": LoRA_type in {"LyCORIS/LoKr"}, + }, + }, + "train_on_input": { + "gr_type": gr.Checkbox, + "update_params": { + "visible": LoRA_type in {"LyCORIS/iA3"}, + }, + }, + "scale_weight_norms": { + "gr_type": gr.Slider, + "update_params": { + "visible": LoRA_type + in { + "LoCon", + "Kohya DyLoRA", + "Kohya LoCon", + "LoRA-FA", + "LyCORIS/DyLoRA", + "LyCORIS/GLoRA", + "LyCORIS/LoHa", + "LyCORIS/LoCon", + "LyCORIS/LoKr", + "Standard", + }, + }, + }, + "network_dropout": { + "gr_type": gr.Slider, + "update_params": { + "visible": LoRA_type + in { + "LoCon", + "Kohya DyLoRA", + "Kohya LoCon", + "LoRA-FA", + "LyCORIS/Diag-OFT", + "LyCORIS/DyLoRA", + "LyCORIS/GLoRA", + "LyCORIS/LoCon", + "LyCORIS/LoHa", + "LyCORIS/LoKr", + "LyCORIS/Native Fine-Tuning", + "Standard", + }, + }, + }, + "rank_dropout": { + "gr_type": gr.Slider, + "update_params": { + "visible": LoRA_type + in { + "LoCon", + "Kohya DyLoRA", + "LyCORIS/GLoRA", + "LyCORIS/LoCon", + "LyCORIS/LoHa", + "LyCORIS/LoKR", + "Kohya LoCon", + "LoRA-FA", + "LyCORIS/Native Fine-Tuning", + "Standard", + }, + }, + }, + "module_dropout": { + "gr_type": gr.Slider, + "update_params": { + "visible": LoRA_type + in { + "LoCon", + "LyCORIS/Diag-OFT", + "Kohya DyLoRA", + "LyCORIS/GLoRA", + "LyCORIS/LoCon", + "LyCORIS/LoHa", + "LyCORIS/LoKR", 
+ "Kohya LoCon", + "LyCORIS/Native Fine-Tuning", + "LoRA-FA", + "Standard", + }, + }, + }, + "LyCORIS_preset": { + "gr_type": gr.Dropdown, + "update_params": { + "visible": LoRA_type + in { + "LyCORIS/DyLoRA", + "LyCORIS/iA3", + "LyCORIS/Diag-OFT", + "LyCORIS/GLoRA", + "LyCORIS/LoCon", + "LyCORIS/LoHa", + "LyCORIS/LoKr", + "LyCORIS/Native Fine-Tuning", + }, + }, + }, + } + + results = [] + for attr, settings in lora_settings_config.items(): + update_params = settings["update_params"] + + results.append(settings["gr_type"](**update_params)) + + return tuple(results) + + with gr.Tab("Advanced", elem_id="advanced_tab"): + # with gr.Accordion('Advanced Configuration', open=False): + with gr.Row(visible=True) as kohya_advanced_lora: + with gr.Tab(label="Weights"): + with gr.Row(visible=True): + down_lr_weight = gr.Textbox( + label="Down LR weights", + placeholder="(Optional) eg: 0,0,0,0,0,0,1,1,1,1,1,1", + info="Specify the learning rate weight of the down blocks of U-Net.", + ) + mid_lr_weight = gr.Textbox( + label="Mid LR weights", + placeholder="(Optional) eg: 0.5", + info="Specify the learning rate weight of the mid block of U-Net.", + ) + up_lr_weight = gr.Textbox( + label="Up LR weights", + placeholder="(Optional) eg: 0,0,0,0,0,0,1,1,1,1,1,1", + info="Specify the learning rate weight of the up blocks of U-Net. The same as down_lr_weight.", + ) + block_lr_zero_threshold = gr.Textbox( + label="Blocks LR zero threshold", + placeholder="(Optional) eg: 0.1", + info="If the weight is not more than this value, the LoRA module is not created. The default is 0.", + ) + with gr.Tab(label="Blocks"): + with gr.Row(visible=True): + block_dims = gr.Textbox( + label="Block dims", + placeholder="(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2", + info="Specify the dim (rank) of each block. Specify 25 numbers.", + ) + block_alphas = gr.Textbox( + label="Block alphas", + placeholder="(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2", + info="Specify the alpha of each block. Specify 25 numbers as with block_dims. If omitted, the value of network_alpha is used.", + ) + with gr.Tab(label="Conv"): + with gr.Row(visible=True): + conv_block_dims = gr.Textbox( + label="Conv dims", + placeholder="(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2", + info="Extend LoRA to Conv2d 3x3 and specify the dim (rank) of each block. Specify 25 numbers.", + ) + conv_block_alphas = gr.Textbox( + label="Conv alphas", + placeholder="(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2", + info="Specify the alpha of each block when expanding LoRA to Conv2d 3x3. Specify 25 numbers. 
If omitted, the value of conv_alpha is used.", + ) + advanced_training = AdvancedTraining( + headless=headless, training_type="lora" + ) + advanced_training.color_aug.change( + color_aug_changed, + inputs=[advanced_training.color_aug], + outputs=[basic_training.cache_latents], + ) + + with gr.Tab("Samples", elem_id="samples_tab"): + sample = SampleImages() + + LoRA_type.change( + update_LoRA_settings, + inputs=[ + LoRA_type, + conv_dim, + network_dim, + ], + outputs=[ + network_row, + convolution_row, + kohya_advanced_lora, + kohya_dylora, + lora_network_weights, + lora_network_weights_file, + dim_from_weights, + factor, + conv_dim, + network_dim, + use_cp, + use_tucker, + use_scalar, + rank_dropout_scale, + constrain, + rescaled, + train_norm, + decompose_both, + train_on_input, + scale_weight_norms, + network_dropout, + rank_dropout, + module_dropout, + LyCORIS_preset, + ], + ) + + with gr.Tab("Dataset Preparation"): + gr.Markdown( + "This section provides Dreambooth tools to help set up your dataset..." + ) + gradio_dreambooth_folder_creation_tab( + train_data_dir_input=folders.train_data_dir, + reg_data_dir_input=folders.reg_data_dir, + output_dir_input=folders.output_dir, + logging_dir_input=folders.logging_dir, + headless=headless, + ) + gradio_dataset_balancing_tab(headless=headless) + + with gr.Row(): + button_run = gr.Button("Start training", variant="primary") + + button_stop_training = gr.Button("Stop training") + + button_print = gr.Button("Print training command") + + # Setup gradio tensorboard buttons + ( + button_start_tensorboard, + button_stop_tensorboard, + ) = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=[dummy_headless, folders.logging_dir], + show_progress=False, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + show_progress=False, + ) + + settings_list = [ + source_model.pretrained_model_name_or_path, + source_model.v2, + source_model.v_parameterization, + source_model.sdxl_checkbox, + folders.logging_dir, + folders.train_data_dir, + folders.reg_data_dir, + folders.output_dir, + basic_training.max_resolution, + basic_training.learning_rate, + basic_training.lr_scheduler, + basic_training.lr_warmup, + basic_training.train_batch_size, + basic_training.epoch, + basic_training.save_every_n_epochs, + basic_training.mixed_precision, + basic_training.save_precision, + basic_training.seed, + basic_training.num_cpu_threads_per_process, + basic_training.cache_latents, + basic_training.cache_latents_to_disk, + basic_training.caption_extension, + basic_training.enable_bucket, + advanced_training.gradient_checkpointing, + advanced_training.fp8_base, + advanced_training.full_fp16, + # advanced_training.no_token_padding, + basic_training.stop_text_encoder_training, + basic_training.min_bucket_reso, + basic_training.max_bucket_reso, + advanced_training.xformers, + source_model.save_model_as, + advanced_training.shuffle_caption, + advanced_training.save_state, + advanced_training.resume, + advanced_training.prior_loss_weight, + text_encoder_lr, + unet_lr, + network_dim, + lora_network_weights, + dim_from_weights, + advanced_training.color_aug, + advanced_training.flip_aug, + advanced_training.clip_skip, + advanced_training.num_processes, + advanced_training.num_machines, + advanced_training.multi_gpu, + advanced_training.gpu_ids, + advanced_training.gradient_accumulation_steps, + advanced_training.mem_eff_attn, + folders.output_name, + source_model.model_list, + advanced_training.max_token_length, + basic_training.max_train_epochs, +
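# NOTE: settings_list must stay in the same order as the positional parameters of + # save_configuration / open_configuration / train_model above, since Gradio passes + # these components to the handlers strictly by position. +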
basic_training.max_train_steps, + advanced_training.max_data_loader_n_workers, + network_alpha, + folders.training_comment, + advanced_training.keep_tokens, + basic_training.lr_scheduler_num_cycles, + basic_training.lr_scheduler_power, + advanced_training.persistent_data_loader_workers, + advanced_training.bucket_no_upscale, + advanced_training.random_crop, + advanced_training.bucket_reso_steps, + advanced_training.v_pred_like_loss, + advanced_training.caption_dropout_every_n_epochs, + advanced_training.caption_dropout_rate, + basic_training.optimizer, + basic_training.optimizer_args, + basic_training.lr_scheduler_args, + basic_training.max_grad_norm, + advanced_training.noise_offset_type, + advanced_training.noise_offset, + advanced_training.adaptive_noise_scale, + advanced_training.multires_noise_iterations, + advanced_training.multires_noise_discount, + LoRA_type, + factor, + use_cp, + use_tucker, + use_scalar, + rank_dropout_scale, + constrain, + rescaled, + train_norm, + decompose_both, + train_on_input, + conv_dim, + conv_alpha, + sample.sample_every_n_steps, + sample.sample_every_n_epochs, + sample.sample_sampler, + sample.sample_prompts, + advanced_training.additional_parameters, + advanced_training.vae_batch_size, + advanced_training.min_snr_gamma, + down_lr_weight, + mid_lr_weight, + up_lr_weight, + block_lr_zero_threshold, + block_dims, + block_alphas, + conv_block_dims, + conv_block_alphas, + advanced_training.weighted_captions, + unit, + advanced_training.save_every_n_steps, + advanced_training.save_last_n_steps, + advanced_training.save_last_n_steps_state, + advanced_training.use_wandb, + advanced_training.wandb_api_key, + advanced_training.scale_v_pred_loss_like_noise_pred, + scale_weight_norms, + network_dropout, + rank_dropout, + module_dropout, + sdxl_params.sdxl_cache_text_encoder_outputs, + sdxl_params.sdxl_no_half_vae, + advanced_training.full_bf16, + advanced_training.min_timestep, + advanced_training.max_timestep, + advanced_training.vae, + LyCORIS_preset, + advanced_training.debiased_estimation_loss, + ] + + config.button_open_config.click( + open_configuration, + inputs=[dummy_db_true, dummy_db_false, config.config_file_name] + + settings_list + + [training_preset], + outputs=[config.config_file_name] + + settings_list + + [training_preset, convolution_row], + show_progress=False, + ) + + config.button_load_config.click( + open_configuration, + inputs=[dummy_db_false, dummy_db_false, config.config_file_name] + + settings_list + + [training_preset], + outputs=[config.config_file_name] + + settings_list + + [training_preset, convolution_row], + show_progress=False, + ) + + training_preset.input( + open_configuration, + inputs=[dummy_db_false, dummy_db_true, config.config_file_name] + + settings_list + + [training_preset], + outputs=[gr.Textbox(visible=False)] + settings_list + [training_preset, convolution_row], + show_progress=False, + ) + + config.button_save_config.click( + save_configuration, + inputs=[dummy_db_false, config.config_file_name] + settings_list, + outputs=[config.config_file_name], + show_progress=False, + ) + + config.button_save_as_config.click( + save_configuration, + inputs=[dummy_db_true, config.config_file_name] + settings_list, + outputs=[config.config_file_name], + show_progress=False, + ) + + button_run.click( + train_model, + inputs=[dummy_headless] + [dummy_db_false] + settings_list, + show_progress=False, + ) + + button_stop_training.click(executor.kill_command) + + button_print.click( + train_model, + inputs=[dummy_headless] + 
[dummy_db_true] + settings_list, + show_progress=False, + ) + + with gr.Tab("Tools"): + lora_tools = LoRATools(folders=folders, headless=headless) + + with gr.Tab("Guides"): + gr.Markdown("This section provides various LoRA guides and information...") + if os.path.exists("./docs/LoRA/top_level.md"): + with open( + "./docs/LoRA/top_level.md", "r", encoding="utf8" + ) as file: + guides_top_level = file.read() + "\n" + gr.Markdown(guides_top_level) + + return ( + folders.train_data_dir, + folders.reg_data_dir, + folders.output_dir, + folders.logging_dir, + ) diff --git a/library/manual_caption_gui.py b/kohya_gui/manual_caption_gui.py similarity index 99% rename from library/manual_caption_gui.py rename to kohya_gui/manual_caption_gui.py index 0f7d5b008..c7d108b12 100644 --- a/library/manual_caption_gui.py +++ b/kohya_gui/manual_caption_gui.py @@ -5,7 +5,7 @@ import os import re -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/library/merge_lora_gui.py b/kohya_gui/merge_lora_gui.py similarity index 99% rename from library/merge_lora_gui.py rename to kohya_gui/merge_lora_gui.py index db2346f57..76004760f 100644 --- a/library/merge_lora_gui.py +++ b/kohya_gui/merge_lora_gui.py @@ -8,7 +8,7 @@ # Local module imports from .common_gui import get_saveasfilename_path, get_file_path -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/library/merge_lycoris_gui.py b/kohya_gui/merge_lycoris_gui.py similarity index 99% rename from library/merge_lycoris_gui.py rename to kohya_gui/merge_lycoris_gui.py index bb084ffdc..d191ee995 100644 --- a/library/merge_lycoris_gui.py +++ b/kohya_gui/merge_lycoris_gui.py @@ -7,7 +7,7 @@ get_file_path, ) -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/library/resize_lora_gui.py b/kohya_gui/resize_lora_gui.py similarity index 99% rename from library/resize_lora_gui.py rename to kohya_gui/resize_lora_gui.py index 6ad481e3c..32c5335c3 100644 --- a/library/resize_lora_gui.py +++ b/kohya_gui/resize_lora_gui.py @@ -4,7 +4,7 @@ import os from .common_gui import get_saveasfilename_path, get_file_path -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/library/svd_merge_lora_gui.py b/kohya_gui/svd_merge_lora_gui.py similarity index 99% rename from library/svd_merge_lora_gui.py rename to kohya_gui/svd_merge_lora_gui.py index 27d670328..603b26841 100644 --- a/library/svd_merge_lora_gui.py +++ b/kohya_gui/svd_merge_lora_gui.py @@ -8,7 +8,7 @@ get_file_path, ) -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/library/tensorboard_gui.py b/kohya_gui/tensorboard_gui.py similarity index 98% rename from library/tensorboard_gui.py rename to kohya_gui/tensorboard_gui.py index 41fe98f92..cf8dcf110 100644 --- a/library/tensorboard_gui.py +++ b/kohya_gui/tensorboard_gui.py @@ -4,7 +4,7 @@ import subprocess import time import webbrowser -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/kohya_gui/textual_inversion_gui.py b/kohya_gui/textual_inversion_gui.py new file mode 100644 index 000000000..6b9a0bdfa --- /dev/null +++
b/kohya_gui/textual_inversion_gui.py @@ -0,0 +1,961 @@ +# v1: initial release +# v2: add open and save folder icons +# v3: Add new Utilities tab for Dreambooth folder preparation +# v3.1: Adding captioning of images to utilities + +import gradio as gr +import json +import math +import os +import subprocess +import pathlib +from datetime import datetime +from .common_gui import ( + get_file_path, + get_saveasfile_path, + color_aug_changed, + save_inference_file, + run_cmd_advanced_training, + update_my_data, + check_if_model_exist, + output_message, + verify_image_folder_pattern, + SaveConfigFile, + save_to_file, +) +from .class_configuration_file import ConfigurationFile +from .class_source_model import SourceModel +from .class_basic_training import BasicTraining +from .class_advanced_training import AdvancedTraining +from .class_folders import Folders +from .class_sdxl_parameters import SDXLParameters +from .class_command_executor import CommandExecutor +from .tensorboard_gui import ( + gradio_tensorboard, + start_tensorboard, + stop_tensorboard, +) +from .dreambooth_folder_creation_gui import ( + gradio_dreambooth_folder_creation_tab, +) +from .dataset_balancing_gui import gradio_dataset_balancing_tab +from .class_sample_images import SampleImages, run_cmd_sample + +from .custom_logging import setup_logging + +# Set up logging +log = setup_logging() + +# Setup command executor +executor = CommandExecutor() + + +def save_configuration( + save_as, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + sdxl, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + cache_latents_to_disk, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + min_bucket_reso, + max_bucket_reso, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + num_processes, + num_machines, + multi_gpu, + gpu_ids, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + token_string, + init_word, + num_vectors_per_token, + max_train_steps, + weights, + template, + keep_tokens, + lr_scheduler_num_cycles, + lr_scheduler_power, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + v_pred_like_loss, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + lr_scheduler_args, + noise_offset_type, + noise_offset, + adaptive_noise_scale, + multires_noise_iterations, + multires_noise_discount, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, + save_every_n_steps, + save_last_n_steps, + save_last_n_steps_state, + use_wandb, + wandb_api_key, + scale_v_pred_loss_like_noise_pred, + min_timestep, + max_timestep, + sdxl_no_half_vae, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + original_file_path = file_path + + save_as_bool = True if save_as.get("label") == "True" else False + + if save_as_bool: + log.info("Save as...") + file_path = get_saveasfile_path(file_path) + else: + log.info("Save...") + if file_path == None or
file_path == "": + file_path = get_saveasfile_path(file_path) + + # log.info(file_path) + + if file_path == None or file_path == "": + return original_file_path # In case a file_path was provided and the user decide to cancel the open action + + # Extract the destination directory from the file path + destination_directory = os.path.dirname(file_path) + + # Create the destination directory if it doesn't exist + if not os.path.exists(destination_directory): + os.makedirs(destination_directory) + + SaveConfigFile( + parameters=parameters, + file_path=file_path, + exclusion=["file_path", "save_as"], + ) + + return file_path + + +def open_configuration( + ask_for_file, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + sdxl, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + cache_latents_to_disk, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + min_bucket_reso, + max_bucket_reso, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + num_processes, + num_machines, + multi_gpu, + gpu_ids, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + token_string, + init_word, + num_vectors_per_token, + max_train_steps, + weights, + template, + keep_tokens, + lr_scheduler_num_cycles, + lr_scheduler_power, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + v_pred_like_loss, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + lr_scheduler_args, + noise_offset_type, + noise_offset, + adaptive_noise_scale, + multires_noise_iterations, + multires_noise_discount, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, + save_every_n_steps, + save_last_n_steps, + save_last_n_steps_state, + use_wandb, + wandb_api_key, + scale_v_pred_loss_like_noise_pred, + min_timestep, + max_timestep, + sdxl_no_half_vae, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + ask_for_file = True if ask_for_file.get("label") == "True" else False + + original_file_path = file_path + + if ask_for_file: + file_path = get_file_path(file_path) + + if not file_path == "" and not file_path == None: + # load variables from JSON file + with open(file_path, "r") as f: + my_data = json.load(f) + log.info("Loading config...") + # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True + my_data = update_my_data(my_data) + else: + file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action + my_data = {} + + values = [file_path] + for key, value in parameters: + # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found + if not key in ["ask_for_file", "file_path"]: + values.append(my_data.get(key, value)) + return tuple(values) + + +def train_model( + headless, + print_only, + pretrained_model_name_or_path, + v2, + v_parameterization, + sdxl, + logging_dir, + 
train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + cache_latents_to_disk, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training_pct, + min_bucket_reso, + max_bucket_reso, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + num_processes, + num_machines, + multi_gpu, + gpu_ids, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, # Keep this. Yes, it is unused here but required given the common list used + token_string, + init_word, + num_vectors_per_token, + max_train_steps, + weights, + template, + keep_tokens, + lr_scheduler_num_cycles, + lr_scheduler_power, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + v_pred_like_loss, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + lr_scheduler_args, + noise_offset_type, + noise_offset, + adaptive_noise_scale, + multires_noise_iterations, + multires_noise_discount, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, + save_every_n_steps, + save_last_n_steps, + save_last_n_steps_state, + use_wandb, + wandb_api_key, + scale_v_pred_loss_like_noise_pred, + min_timestep, + max_timestep, + sdxl_no_half_vae, +): + # Get list of function parameters and values + parameters = list(locals().items()) + + print_only_bool = True if print_only.get("label") == "True" else False + log.info(f"Start training TI...") + + headless_bool = True if headless.get("label") == "True" else False + + if pretrained_model_name_or_path == "": + output_message( + msg="Source model information is missing", headless=headless_bool + ) + return + + if train_data_dir == "": + output_message(msg="Image folder path is missing", headless=headless_bool) + return + + if not os.path.exists(train_data_dir): + output_message(msg="Image folder does not exist", headless=headless_bool) + return + + if not verify_image_folder_pattern(train_data_dir): + return + + if reg_data_dir != "": + if not os.path.exists(reg_data_dir): + output_message( + msg="Regularisation folder does not exist", + headless=headless_bool, + ) + return + + if not verify_image_folder_pattern(reg_data_dir): + return + + if output_dir == "": + output_message(msg="Output folder path is missing", headless=headless_bool) + return + + if token_string == "": + output_message(msg="Token string is missing", headless=headless_bool) + return + + if init_word == "": + output_message(msg="Init word is missing", headless=headless_bool) + return + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + if check_if_model_exist(output_name, output_dir, save_model_as, headless_bool): + return + + # if float(noise_offset) > 0 and ( + # multires_noise_iterations > 0 or multires_noise_discount > 0 + # ): + # output_message( + # msg="noise offset and multires_noise can't be set at the same time. 
Only use one or the other.", + # title='Error', + # headless=headless_bool, + # ) + # return + + # if optimizer == 'Adafactor' and lr_warmup != '0': + # output_message( + # msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", + # title='Warning', + # headless=headless_bool, + # ) + # lr_warmup = '0' + + # Get a list of all subfolders in train_data_dir + subfolders = [ + f + for f in os.listdir(train_data_dir) + if os.path.isdir(os.path.join(train_data_dir, f)) + ] + + total_steps = 0 + + # Loop through each subfolder and extract the number of repeats + for folder in subfolders: + # Extract the number of repeats from the folder name + repeats = int(folder.split("_")[0]) + + # Count the number of images in the folder + num_images = len( + [ + f + for f, lower_f in ( + (file, file.lower()) + for file in os.listdir(os.path.join(train_data_dir, folder)) + ) + if lower_f.endswith((".jpg", ".jpeg", ".png", ".webp")) + ] + ) + + # Calculate the total number of steps for this folder + steps = repeats * num_images + total_steps += steps + + # Print the result + log.info(f"Folder {folder}: {steps} steps") + + # Print the result + # log.info(f"{total_steps} total steps") + + if reg_data_dir == "": + reg_factor = 1 + else: + log.info( + "Regularisation images are used... Will double the number of steps required..." + ) + reg_factor = 2 + + # calculate max_train_steps + if max_train_steps == "" or max_train_steps == "0": + max_train_steps = int( + math.ceil( + float(total_steps) + / int(train_batch_size) + / int(gradient_accumulation_steps) + * int(epoch) + * int(reg_factor) + ) + ) + else: + max_train_steps = int(max_train_steps) + + log.info(f"max_train_steps = {max_train_steps}") + + # calculate stop encoder training + if stop_text_encoder_training_pct == None: + stop_text_encoder_training = 0 + else: + stop_text_encoder_training = math.ceil( + float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) + ) + log.info(f"stop_text_encoder_training = {stop_text_encoder_training}") + + lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) + log.info(f"lr_warmup_steps = {lr_warmup_steps}") + + run_cmd = "accelerate launch" + + run_cmd += run_cmd_advanced_training( + num_processes=num_processes, + num_machines=num_machines, + multi_gpu=multi_gpu, + gpu_ids=gpu_ids, + num_cpu_threads_per_process=num_cpu_threads_per_process, + ) + + if sdxl: + run_cmd += f' "./sdxl_train_textual_inversion.py"' + else: + run_cmd += f' "./train_textual_inversion.py"' + + run_cmd += run_cmd_advanced_training( + adaptive_noise_scale=adaptive_noise_scale, + additional_parameters=additional_parameters, + bucket_no_upscale=bucket_no_upscale, + bucket_reso_steps=bucket_reso_steps, + cache_latents=cache_latents, + cache_latents_to_disk=cache_latents_to_disk, + caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, + caption_dropout_rate=caption_dropout_rate, + caption_extension=caption_extension, + clip_skip=clip_skip, + color_aug=color_aug, + enable_bucket=enable_bucket, + epoch=epoch, + flip_aug=flip_aug, + full_fp16=full_fp16, + gradient_accumulation_steps=gradient_accumulation_steps, + gradient_checkpointing=gradient_checkpointing, + keep_tokens=keep_tokens, + learning_rate=learning_rate, + logging_dir=logging_dir, + lr_scheduler=lr_scheduler, + lr_scheduler_args=lr_scheduler_args, + lr_scheduler_num_cycles=lr_scheduler_num_cycles, + lr_scheduler_power=lr_scheduler_power, + lr_warmup_steps=lr_warmup_steps, + max_bucket_reso=max_bucket_reso, + 
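+ # NOTE: every keyword in this call is assumed to be serialized by run_cmd_advanced_training
+ # into a CLI flag appended to run_cmd, e.g. learning_rate="1e-5" -> ' --learning_rate="1e-5"'
+ # and xformers=True -> ' --xformers'; None and empty values are expected to be skipped.
+ # This is a sketch of the expected mapping, not the verified implementation.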
max_data_loader_n_workers=max_data_loader_n_workers, + max_resolution=max_resolution, + max_timestep=max_timestep, + max_token_length=max_token_length, + max_train_epochs=max_train_epochs, + max_train_steps=max_train_steps, + mem_eff_attn=mem_eff_attn, + min_bucket_reso=min_bucket_reso, + min_snr_gamma=min_snr_gamma, + min_timestep=min_timestep, + mixed_precision=mixed_precision, + multires_noise_discount=multires_noise_discount, + multires_noise_iterations=multires_noise_iterations, + no_half_vae=True if sdxl and sdxl_no_half_vae else None, + no_token_padding=no_token_padding, + noise_offset=noise_offset, + noise_offset_type=noise_offset_type, + optimizer=optimizer, + optimizer_args=optimizer_args, + output_dir=output_dir, + output_name=output_name, + persistent_data_loader_workers=persistent_data_loader_workers, + pretrained_model_name_or_path=pretrained_model_name_or_path, + prior_loss_weight=prior_loss_weight, + random_crop=random_crop, + reg_data_dir=reg_data_dir, + resume=resume, + save_every_n_epochs=save_every_n_epochs, + save_every_n_steps=save_every_n_steps, + save_last_n_steps=save_last_n_steps, + save_last_n_steps_state=save_last_n_steps_state, + save_model_as=save_model_as, + save_precision=save_precision, + save_state=save_state, + scale_v_pred_loss_like_noise_pred=scale_v_pred_loss_like_noise_pred, + seed=seed, + shuffle_caption=shuffle_caption, + stop_text_encoder_training=stop_text_encoder_training, + train_batch_size=train_batch_size, + train_data_dir=train_data_dir, + use_wandb=use_wandb, + v2=v2, + v_parameterization=v_parameterization, + v_pred_like_loss=v_pred_like_loss, + vae=vae, + vae_batch_size=vae_batch_size, + wandb_api_key=wandb_api_key, + xformers=xformers, + ) + run_cmd += f' --token_string="{token_string}"' + run_cmd += f' --init_word="{init_word}"' + run_cmd += f" --num_vectors_per_token={num_vectors_per_token}" + if not weights == "": + run_cmd += f' --weights="{weights}"' + if template == "object template": + run_cmd += f" --use_object_template" + elif template == "style template": + run_cmd += f" --use_style_template" + + run_cmd += run_cmd_sample( + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + output_dir, + ) + + if print_only_bool: + log.warning( + "Here is the trainer command as a reference. It will not be executed:\n" + ) + print(run_cmd) + + save_to_file(run_cmd) + else: + # Saving config file for model + current_datetime = datetime.now() + formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") + file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json") + + log.info(f"Saving training config to {file_path}...") + + SaveConfigFile( + parameters=parameters, + file_path=file_path, + exclusion=["file_path", "save_as", "headless", "print_only"], + ) + + log.info(run_cmd) + + # Run the command + + executor.execute_command(run_cmd=run_cmd) + + # check if output_dir/last is a folder... 
therefore it is a diffuser model + last_dir = pathlib.Path(f"{output_dir}/{output_name}") + + if not last_dir.is_dir(): + # Copy inference model for v2 if required + save_inference_file(output_dir, v2, v_parameterization, output_name) + + +def ti_tab( + headless=False, +): + dummy_db_true = gr.Label(value=True, visible=False) + dummy_db_false = gr.Label(value=False, visible=False) + dummy_headless = gr.Label(value=headless, visible=False) + + with gr.Tab("Training"): + gr.Markdown("Train a TI using kohya textual inversion python code...") + + # Setup Configuration Files Gradio + config = ConfigurationFile(headless) + + source_model = SourceModel( + save_model_as_choices=[ + "ckpt", + "safetensors", + ], + headless=headless, + ) + + with gr.Tab("Folders"): + folders = Folders(headless=headless) + with gr.Tab("Parameters"): + with gr.Tab("Basic", elem_id="basic_tab"): + with gr.Row(): + weights = gr.Textbox( + label='Resume TI training', + placeholder='(Optional) Path to existing TI embedding file to keep training', + ) + weights_file_input = gr.Button( + "", + elem_id="open_folder_small", + visible=(not headless), + ) + weights_file_input.click( + get_file_path, + outputs=weights, + show_progress=False, + ) + with gr.Row(): + token_string = gr.Textbox( + label="Token string", + placeholder="e.g. cat", + ) + init_word = gr.Textbox( + label="Init word", + value="*", + ) + num_vectors_per_token = gr.Slider( + minimum=1, + maximum=75, + value=1, + step=1, + label="Vectors", + ) + # max_train_steps = gr.Textbox( + # label='Max train steps', + # placeholder='(Optional) Maximum number of steps', + # ) + template = gr.Dropdown( + label="Template", + choices=[ + "caption", + "object template", + "style template", + ], + value="caption", + ) + basic_training = BasicTraining( + learning_rate_value="1e-5", + lr_scheduler_value="cosine", + lr_warmup_value="10", + sdxl_checkbox=source_model.sdxl_checkbox, + ) + + # Add SDXL Parameters + sdxl_params = SDXLParameters( + source_model.sdxl_checkbox, + show_sdxl_cache_text_encoder_outputs=False, + ) + + with gr.Tab("Advanced", elem_id="advanced_tab"): + advanced_training = AdvancedTraining(headless=headless) + advanced_training.color_aug.change( + color_aug_changed, + inputs=[advanced_training.color_aug], + outputs=[basic_training.cache_latents], + ) + + with gr.Tab("Samples", elem_id="samples_tab"): + sample = SampleImages() + + with gr.Tab("Dataset Preparation"): + gr.Markdown( + "This section provides Dreambooth tools to help set up your dataset..." 
+ ) + gradio_dreambooth_folder_creation_tab( + train_data_dir_input=folders.train_data_dir, + reg_data_dir_input=folders.reg_data_dir, + output_dir_input=folders.output_dir, + logging_dir_input=folders.logging_dir, + headless=headless, + ) + gradio_dataset_balancing_tab(headless=headless) + + with gr.Row(): + button_run = gr.Button("Start training", variant="primary") + + button_stop_training = gr.Button("Stop training") + + button_print = gr.Button("Print training command") + + # Setup gradio tensorboard buttons + ( + button_start_tensorboard, + button_stop_tensorboard, + ) = gradio_tensorboard() + + button_start_tensorboard.click( + start_tensorboard, + inputs=[dummy_headless, folders.logging_dir], + show_progress=False, + ) + + button_stop_tensorboard.click( + stop_tensorboard, + show_progress=False, + ) + + settings_list = [ + source_model.pretrained_model_name_or_path, + source_model.v2, + source_model.v_parameterization, + source_model.sdxl_checkbox, + folders.logging_dir, + folders.train_data_dir, + folders.reg_data_dir, + folders.output_dir, + basic_training.max_resolution, + basic_training.learning_rate, + basic_training.lr_scheduler, + basic_training.lr_warmup, + basic_training.train_batch_size, + basic_training.epoch, + basic_training.save_every_n_epochs, + basic_training.mixed_precision, + basic_training.save_precision, + basic_training.seed, + basic_training.num_cpu_threads_per_process, + basic_training.cache_latents, + basic_training.cache_latents_to_disk, + basic_training.caption_extension, + basic_training.enable_bucket, + advanced_training.gradient_checkpointing, + advanced_training.full_fp16, + advanced_training.no_token_padding, + basic_training.stop_text_encoder_training, + basic_training.min_bucket_reso, + basic_training.max_bucket_reso, + advanced_training.xformers, + source_model.save_model_as, + advanced_training.shuffle_caption, + advanced_training.save_state, + advanced_training.resume, + advanced_training.prior_loss_weight, + advanced_training.color_aug, + advanced_training.flip_aug, + advanced_training.clip_skip, + advanced_training.num_processes, + advanced_training.num_machines, + advanced_training.multi_gpu, + advanced_training.gpu_ids, + advanced_training.vae, + folders.output_name, + advanced_training.max_token_length, + basic_training.max_train_epochs, + advanced_training.max_data_loader_n_workers, + advanced_training.mem_eff_attn, + advanced_training.gradient_accumulation_steps, + source_model.model_list, + token_string, + init_word, + num_vectors_per_token, + basic_training.max_train_steps, + weights, + template, + advanced_training.keep_tokens, + basic_training.lr_scheduler_num_cycles, + basic_training.lr_scheduler_power, + advanced_training.persistent_data_loader_workers, + advanced_training.bucket_no_upscale, + advanced_training.random_crop, + advanced_training.bucket_reso_steps, + advanced_training.v_pred_like_loss, + advanced_training.caption_dropout_every_n_epochs, + advanced_training.caption_dropout_rate, + basic_training.optimizer, + basic_training.optimizer_args, + basic_training.lr_scheduler_args, + advanced_training.noise_offset_type, + advanced_training.noise_offset, + advanced_training.adaptive_noise_scale, + advanced_training.multires_noise_iterations, + advanced_training.multires_noise_discount, + sample.sample_every_n_steps, + sample.sample_every_n_epochs, + sample.sample_sampler, + sample.sample_prompts, + advanced_training.additional_parameters, + advanced_training.vae_batch_size, + advanced_training.min_snr_gamma, + 
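+ # NOTE: the order of this list must mirror the parameter order of save_configuration,
+ # open_configuration and train_model above; Gradio passes these components positionally,
+ # so reordering one side without the other silently scrambles saved and loaded configs.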
advanced_training.save_every_n_steps, + advanced_training.save_last_n_steps, + advanced_training.save_last_n_steps_state, + advanced_training.use_wandb, + advanced_training.wandb_api_key, + advanced_training.scale_v_pred_loss_like_noise_pred, + advanced_training.min_timestep, + advanced_training.max_timestep, + sdxl_params.sdxl_no_half_vae, + ] + + config.button_open_config.click( + open_configuration, + inputs=[dummy_db_true, config.config_file_name] + settings_list, + outputs=[config.config_file_name] + settings_list, + show_progress=False, + ) + + config.button_load_config.click( + open_configuration, + inputs=[dummy_db_false, config.config_file_name] + settings_list, + outputs=[config.config_file_name] + settings_list, + show_progress=False, + ) + + config.button_save_config.click( + save_configuration, + inputs=[dummy_db_false, config.config_file_name] + settings_list, + outputs=[config.config_file_name], + show_progress=False, + ) + + config.button_save_as_config.click( + save_configuration, + inputs=[dummy_db_true, config.config_file_name] + settings_list, + outputs=[config.config_file_name], + show_progress=False, + ) + + button_run.click( + train_model, + inputs=[dummy_headless] + [dummy_db_false] + settings_list, + show_progress=False, + ) + + button_stop_training.click(executor.kill_command) + + button_print.click( + train_model, + inputs=[dummy_headless] + [dummy_db_true] + settings_list, + show_progress=False, + ) + + return ( + folders.train_data_dir, + folders.reg_data_dir, + folders.output_dir, + folders.logging_dir, + ) diff --git a/kohya_gui/utilities.py b/kohya_gui/utilities.py new file mode 100644 index 000000000..887de8be9 --- /dev/null +++ b/kohya_gui/utilities.py @@ -0,0 +1,41 @@ +# v1: initial release +# v2: add open and save folder icons +# v3: Add new Utilities tab for Dreambooth folder preparation +# v3.1: Adding captioning of images to utilities + +import gradio as gr +import os + +from .basic_caption_gui import gradio_basic_caption_gui_tab +from .convert_model_gui import gradio_convert_model_tab +from .blip_caption_gui import gradio_blip_caption_gui_tab +from .git_caption_gui import gradio_git_caption_gui_tab +from .wd14_caption_gui import gradio_wd14_caption_gui_tab +from .manual_caption_gui import gradio_manual_caption_gui_tab +from .group_images_gui import gradio_group_images_gui_tab + + +def utilities_tab( + train_data_dir_input=gr.Textbox(), + reg_data_dir_input=gr.Textbox(), + output_dir_input=gr.Textbox(), + logging_dir_input=gr.Textbox(), + enable_copy_info_button=False, + enable_dreambooth_tab=True, + headless=False +): + with gr.Tab('Captioning'): + gradio_basic_caption_gui_tab(headless=headless) + gradio_blip_caption_gui_tab(headless=headless) + gradio_git_caption_gui_tab(headless=headless) + gradio_wd14_caption_gui_tab(headless=headless) + gradio_manual_caption_gui_tab(headless=headless) + gradio_convert_model_tab(headless=headless) + gradio_group_images_gui_tab(headless=headless) + + return ( + train_data_dir_input, + reg_data_dir_input, + output_dir_input, + logging_dir_input, + ) diff --git a/library/verify_lora_gui.py b/kohya_gui/verify_lora_gui.py similarity index 98% rename from library/verify_lora_gui.py rename to kohya_gui/verify_lora_gui.py index b98abf66d..7da872337 100644 --- a/library/verify_lora_gui.py +++ b/kohya_gui/verify_lora_gui.py @@ -8,7 +8,7 @@ get_file_path, ) -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git 
a/library/wd14_caption_gui.py b/kohya_gui/wd14_caption_gui.py similarity index 99% rename from library/wd14_caption_gui.py rename to kohya_gui/wd14_caption_gui.py index ae171a8ec..6bf72ec1d 100644 --- a/library/wd14_caption_gui.py +++ b/kohya_gui/wd14_caption_gui.py @@ -4,7 +4,7 @@ from .common_gui import get_folder_path, add_pre_postfix import os -from library.custom_logging import setup_logging +from .custom_logging import setup_logging # Set up logging log = setup_logging() diff --git a/lora_gui.py b/lora_gui.py index e0320b28e..dd18859fa 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -1,2026 +1,16 @@ +import argparse import gradio as gr -import json -import math import os -import argparse -import lycoris -from datetime import datetime -from library.common_gui import ( - get_file_path, - get_any_file_path, - get_saveasfile_path, - color_aug_changed, - run_cmd_advanced_training, - update_my_data, - check_if_model_exist, - output_message, - verify_image_folder_pattern, - SaveConfigFile, - save_to_file, - check_duplicate_filenames, -) -from library.class_configuration_file import ConfigurationFile -from library.class_source_model import SourceModel -from library.class_basic_training import BasicTraining -from library.class_advanced_training import AdvancedTraining -from library.class_sdxl_parameters import SDXLParameters -from library.class_folders import Folders -from library.class_command_executor import CommandExecutor -from library.tensorboard_gui import ( - gradio_tensorboard, - start_tensorboard, - stop_tensorboard, -) -from library.utilities import utilities_tab -from library.class_sample_images import SampleImages, run_cmd_sample -from library.class_lora_tab import LoRATools -from library.dreambooth_folder_creation_gui import ( - gradio_dreambooth_folder_creation_tab, -) -from library.dataset_balancing_gui import gradio_dataset_balancing_tab +from kohya_gui.utilities import utilities_tab +from kohya_gui.lora_gui import lora_tab -from library.custom_logging import setup_logging -from library.localization_ext import add_javascript +from kohya_gui.custom_logging import setup_logging +from kohya_gui.localization_ext import add_javascript # Set up logging log = setup_logging() -# Setup command executor -executor = CommandExecutor() - -button_run = gr.Button("Start training", variant="primary") - -button_stop_training = gr.Button("Stop training") - -document_symbol = "\U0001F4C4" # 📄 - - -def save_configuration( - save_as, - file_path, - pretrained_model_name_or_path, - v2, - v_parameterization, - sdxl, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - cache_latents_to_disk, - caption_extension, - enable_bucket, - gradient_checkpointing, - fp8_base, - full_fp16, - # no_token_padding, - stop_text_encoder_training, - min_bucket_reso, - max_bucket_reso, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - text_encoder_lr, - unet_lr, - network_dim, - lora_network_weights, - dim_from_weights, - color_aug, - flip_aug, - clip_skip, - num_processes, - num_machines, - multi_gpu, - gpu_ids, - gradient_accumulation_steps, - mem_eff_attn, - output_name, - model_list, - max_token_length, - max_train_epochs, - max_train_steps, - max_data_loader_n_workers, - network_alpha, - training_comment, - keep_tokens, - 
lr_scheduler_num_cycles, - lr_scheduler_power, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - v_pred_like_loss, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - lr_scheduler_args, - max_grad_norm, - noise_offset_type, - noise_offset, - adaptive_noise_scale, - multires_noise_iterations, - multires_noise_discount, - LoRA_type, - factor, - use_cp, - use_tucker, - use_scalar, - rank_dropout_scale, - constrain, - rescaled, - train_norm, - decompose_both, - train_on_input, - conv_dim, - conv_alpha, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, - down_lr_weight, - mid_lr_weight, - up_lr_weight, - block_lr_zero_threshold, - block_dims, - block_alphas, - conv_block_dims, - conv_block_alphas, - weighted_captions, - unit, - save_every_n_steps, - save_last_n_steps, - save_last_n_steps_state, - use_wandb, - wandb_api_key, - scale_v_pred_loss_like_noise_pred, - scale_weight_norms, - network_dropout, - rank_dropout, - module_dropout, - sdxl_cache_text_encoder_outputs, - sdxl_no_half_vae, - full_bf16, - min_timestep, - max_timestep, - vae, - LyCORIS_preset, - debiased_estimation_loss, -): - # Get list of function parameters and values - parameters = list(locals().items()) - - original_file_path = file_path - - save_as_bool = True if save_as.get("label") == "True" else False - - if save_as_bool: - log.info("Save as...") - file_path = get_saveasfile_path(file_path) - else: - log.info("Save...") - if file_path == None or file_path == "": - file_path = get_saveasfile_path(file_path) - - # log.info(file_path) - - if file_path == None or file_path == "": - return original_file_path # In case a file_path was provided and the user decide to cancel the open action - - # Extract the destination directory from the file path - destination_directory = os.path.dirname(file_path) - - # Create the destination directory if it doesn't exist - if not os.path.exists(destination_directory): - os.makedirs(destination_directory) - - SaveConfigFile( - parameters=parameters, - file_path=file_path, - exclusion=["file_path", "save_as"], - ) - - return file_path - - -def open_configuration( - ask_for_file, - apply_preset, - file_path, - pretrained_model_name_or_path, - v2, - v_parameterization, - sdxl, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - cache_latents_to_disk, - caption_extension, - enable_bucket, - gradient_checkpointing, - fp8_base, - full_fp16, - # no_token_padding, - stop_text_encoder_training, - min_bucket_reso, - max_bucket_reso, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - text_encoder_lr, - unet_lr, - network_dim, - lora_network_weights, - dim_from_weights, - color_aug, - flip_aug, - clip_skip, - num_processes, - num_machines, - multi_gpu, - gpu_ids, - gradient_accumulation_steps, - mem_eff_attn, - output_name, - model_list, - max_token_length, - max_train_epochs, - max_train_steps, - max_data_loader_n_workers, - network_alpha, - training_comment, - keep_tokens, - lr_scheduler_num_cycles, - lr_scheduler_power, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - v_pred_like_loss, - 
caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - lr_scheduler_args, - max_grad_norm, - noise_offset_type, - noise_offset, - adaptive_noise_scale, - multires_noise_iterations, - multires_noise_discount, - LoRA_type, - factor, - use_cp, - use_tucker, - use_scalar, - rank_dropout_scale, - constrain, - rescaled, - train_norm, - decompose_both, - train_on_input, - conv_dim, - conv_alpha, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, - down_lr_weight, - mid_lr_weight, - up_lr_weight, - block_lr_zero_threshold, - block_dims, - block_alphas, - conv_block_dims, - conv_block_alphas, - weighted_captions, - unit, - save_every_n_steps, - save_last_n_steps, - save_last_n_steps_state, - use_wandb, - wandb_api_key, - scale_v_pred_loss_like_noise_pred, - scale_weight_norms, - network_dropout, - rank_dropout, - module_dropout, - sdxl_cache_text_encoder_outputs, - sdxl_no_half_vae, - full_bf16, - min_timestep, - max_timestep, - vae, - LyCORIS_preset, - debiased_estimation_loss, - training_preset, -): - # Get list of function parameters and values - parameters = list(locals().items()) - - ask_for_file = True if ask_for_file.get("label") == "True" else False - apply_preset = True if apply_preset.get("label") == "True" else False - - # Check if we are "applying" a preset or a config - if apply_preset: - if training_preset != "none": - log.info(f"Applying preset {training_preset}...") - file_path = f"./presets/lora/{training_preset}.json" - else: - # If not applying a preset, reset the `training_preset` field to "none" - # Find the index of the `training_preset` parameter using the `index()` method - training_preset_index = parameters.index(("training_preset", training_preset)) - - # Update the value of `training_preset` by directly assigning the value "none" - parameters[training_preset_index] = ("training_preset", "none") - - original_file_path = file_path - - if ask_for_file: - file_path = get_file_path(file_path) - - if file_path is not None and file_path != "": - # Load variables from JSON file - with open(file_path, "r") as f: - my_data = json.load(f) - log.info("Loading config...") - - # Update values to fix deprecated options, set appropriate optimizer if it is set to True, etc. 
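- # Illustrative example (assumed update_my_data behaviour, not verified here): a legacy
- # config saved with {"use_8bit_adam": true} would be rewritten to the equivalent
- # {"optimizer": "AdamW8bit"} entry before the values are applied.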
- my_data = update_my_data(my_data) - else: - file_path = original_file_path # In case a file_path was provided and the user decides to cancel the open action - my_data = {} - - values = [file_path] - for key, value in parameters: - # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found - if key not in ["ask_for_file", "apply_preset", "file_path"]: - json_value = my_data.get(key) - # if isinstance(json_value, str) and json_value == '': - # # If the JSON value is an empty string, use the default value - # values.append(value) - # else: - # Otherwise, use the JSON value if not None, otherwise use the default value - values.append(json_value if json_value is not None else value) - - # This next section makes the LoCon/convolution parameters row visible when the loaded LoRA_type is one of the conv-capable types below - if my_data.get("LoRA_type", "Standard") in { - "LoCon", - "Kohya DyLoRA", - "Kohya LoCon", - "LoRA-FA", - "LyCORIS/Diag-OFT", - "LyCORIS/DyLoRA", - "LyCORIS/LoHa", - "LyCORIS/LoKr", - "LyCORIS/LoCon", - "LyCORIS/GLoRA", - }: - values.append(gr.Row(visible=True)) - else: - values.append(gr.Row(visible=False)) - - return tuple(values) - - -def train_model( - headless, - print_only, - pretrained_model_name_or_path, - v2, - v_parameterization, - sdxl, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - cache_latents_to_disk, - caption_extension, - enable_bucket, - gradient_checkpointing, - fp8_base, - full_fp16, - # no_token_padding, - stop_text_encoder_training_pct, - min_bucket_reso, - max_bucket_reso, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - text_encoder_lr, - unet_lr, - network_dim, - lora_network_weights, - dim_from_weights, - color_aug, - flip_aug, - clip_skip, - num_processes, - num_machines, - multi_gpu, - gpu_ids, - gradient_accumulation_steps, - mem_eff_attn, - output_name, - model_list, # Keep this. 
Yes, it is unused here but required given the common list used - max_token_length, - max_train_epochs, - max_train_steps, - max_data_loader_n_workers, - network_alpha, - training_comment, - keep_tokens, - lr_scheduler_num_cycles, - lr_scheduler_power, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - v_pred_like_loss, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - lr_scheduler_args, - max_grad_norm, - noise_offset_type, - noise_offset, - adaptive_noise_scale, - multires_noise_iterations, - multires_noise_discount, - LoRA_type, - factor, - use_cp, - use_tucker, - use_scalar, - rank_dropout_scale, - constrain, - rescaled, - train_norm, - decompose_both, - train_on_input, - conv_dim, - conv_alpha, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, - down_lr_weight, - mid_lr_weight, - up_lr_weight, - block_lr_zero_threshold, - block_dims, - block_alphas, - conv_block_dims, - conv_block_alphas, - weighted_captions, - unit, - save_every_n_steps, - save_last_n_steps, - save_last_n_steps_state, - use_wandb, - wandb_api_key, - scale_v_pred_loss_like_noise_pred, - scale_weight_norms, - network_dropout, - rank_dropout, - module_dropout, - sdxl_cache_text_encoder_outputs, - sdxl_no_half_vae, - full_bf16, - min_timestep, - max_timestep, - vae, - LyCORIS_preset, - debiased_estimation_loss, -): - # Get list of function parameters and values - parameters = list(locals().items()) - global command_running - - print_only_bool = True if print_only.get("label") == "True" else False - log.info(f"Start training LoRA {LoRA_type} ...") - headless_bool = True if headless.get("label") == "True" else False - - if pretrained_model_name_or_path == "": - output_message( - msg="Source model information is missing", headless=headless_bool - ) - return - - if train_data_dir == "": - output_message(msg="Image folder path is missing", headless=headless_bool) - return - - # Check if there are files with the same filename but different image extensions... warn the user if that is the case. - check_duplicate_filenames(train_data_dir) - - if not os.path.exists(train_data_dir): - output_message(msg="Image folder does not exist", headless=headless_bool) - return - - if not verify_image_folder_pattern(train_data_dir): - return - - if reg_data_dir != "": - if not os.path.exists(reg_data_dir): - output_message( - msg="Regularisation folder does not exist", - headless=headless_bool, - ) - return - - if not verify_image_folder_pattern(reg_data_dir): - return - - if output_dir == "": - output_message(msg="Output folder path is missing", headless=headless_bool) - return - - if int(bucket_reso_steps) < 1: - output_message( - msg="Bucket resolution steps need to be greater than 0", - headless=headless_bool, - ) - return - - if noise_offset == "": - noise_offset = 0 - - if float(noise_offset) > 1 or float(noise_offset) < 0: - output_message( - msg="Noise offset needs to be a value between 0 and 1", - headless=headless_bool, - ) - return - - # if float(noise_offset) > 0 and ( - # multires_noise_iterations > 0 or multires_noise_discount > 0 - # ): - # output_message( - # msg="noise offset and multires_noise can't be set at the same time. 
Only use one or the other.", - # title='Error', - # headless=headless_bool, - # ) - # return - - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - if stop_text_encoder_training_pct > 0: - output_message( - msg='Output "stop text encoder training" is not yet supported. Ignoring', - headless=headless_bool, - ) - stop_text_encoder_training_pct = 0 - - if check_if_model_exist( - output_name, output_dir, save_model_as, headless=headless_bool - ): - return - - # if optimizer == 'Adafactor' and lr_warmup != '0': - # output_message( - # msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", - # title='Warning', - # headless=headless_bool, - # ) - # lr_warmup = '0' - - # If string is empty set string to 0. - if text_encoder_lr == "": - text_encoder_lr = 0 - if unet_lr == "": - unet_lr = 0 - - # Get a list of all subfolders in train_data_dir - subfolders = [ - f - for f in os.listdir(train_data_dir) - if os.path.isdir(os.path.join(train_data_dir, f)) - ] - - total_steps = 0 - - # Loop through each subfolder and extract the number of repeats - for folder in subfolders: - try: - # Extract the number of repeats from the folder name - repeats = int(folder.split("_")[0]) - - # Count the number of images in the folder - num_images = len( - [ - f - for f, lower_f in ( - (file, file.lower()) - for file in os.listdir(os.path.join(train_data_dir, folder)) - ) - if lower_f.endswith((".jpg", ".jpeg", ".png", ".webp")) - ] - ) - - log.info(f"Folder {folder}: {num_images} images found") - - # Calculate the total number of steps for this folder - steps = repeats * num_images - - # log.info the result - log.info(f"Folder {folder}: {steps} steps") - - total_steps += steps - - except ValueError: - # Handle the case where the folder name does not contain an underscore - log.info(f"Error: '{folder}' does not contain an underscore, skipping...") - - if reg_data_dir == "": - reg_factor = 1 - else: - log.warning( - "Regularisation images are used... Will double the number of steps required..." 
- ) - reg_factor = 2 - - log.info(f"Total steps: {total_steps}") - log.info(f"Train batch size: {train_batch_size}") - log.info(f"Gradient accumulation steps: {gradient_accumulation_steps}") - log.info(f"Epoch: {epoch}") - log.info(f"Regularization factor: {reg_factor}") - - if max_train_steps == "" or max_train_steps == "0": - # calculate max_train_steps - max_train_steps = int( - math.ceil( - float(total_steps) - / int(train_batch_size) - / int(gradient_accumulation_steps) - * int(epoch) - * int(reg_factor) - ) - ) - log.info( - f"max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}" - ) - - # calculate stop text encoder training - if stop_text_encoder_training_pct is None: - stop_text_encoder_training = 0 - else: - stop_text_encoder_training = math.ceil( - float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) - ) - log.info(f"stop_text_encoder_training = {stop_text_encoder_training}") - - lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) - log.info(f"lr_warmup_steps = {lr_warmup_steps}") - - run_cmd = "accelerate launch" - - run_cmd += run_cmd_advanced_training( - num_processes=num_processes, - num_machines=num_machines, - multi_gpu=multi_gpu, - gpu_ids=gpu_ids, - num_cpu_threads_per_process=num_cpu_threads_per_process, - ) - - if sdxl: - run_cmd += f' "./sdxl_train_network.py"' - else: - run_cmd += f' "./train_network.py"' - - if LoRA_type == "LyCORIS/Diag-OFT": - network_module = "lycoris.kohya" - network_args = f' "preset={LyCORIS_preset}" "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "module_dropout={module_dropout}" "use_tucker={use_tucker}" "use_scalar={use_scalar}" "rank_dropout_scale={rank_dropout_scale}" "constrain={constrain}" "rescaled={rescaled}" "algo=diag-oft" ' - - if LoRA_type == "LyCORIS/DyLoRA": - network_module = "lycoris.kohya" - network_args = f' "preset={LyCORIS_preset}" "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "use_tucker={use_tucker}" "block_size={unit}" "rank_dropout={rank_dropout}" "module_dropout={module_dropout}" "algo=dylora" "train_norm={train_norm}"' - - if LoRA_type == "LyCORIS/GLoRA": - network_module = "lycoris.kohya" - network_args = f' "preset={LyCORIS_preset}" "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "rank_dropout={rank_dropout}" "module_dropout={module_dropout}" "rank_dropout_scale={rank_dropout_scale}" "algo=glora" "train_norm={train_norm}"' - - if LoRA_type == "LyCORIS/iA3": - network_module = "lycoris.kohya" - network_args = f' "preset={LyCORIS_preset}" "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "train_on_input={train_on_input}" "algo=ia3"' - - if LoRA_type == "LoCon" or LoRA_type == "LyCORIS/LoCon": - network_module = "lycoris.kohya" - network_args = f' "preset={LyCORIS_preset}" "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "rank_dropout={rank_dropout}" "module_dropout={module_dropout}" "use_tucker={use_tucker}" "use_scalar={use_scalar}" "rank_dropout_scale={rank_dropout_scale}" "algo=locon" "train_norm={train_norm}"' - - if LoRA_type == "LyCORIS/LoHa": - network_module = "lycoris.kohya" - network_args = f' "preset={LyCORIS_preset}" "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "rank_dropout={rank_dropout}" "module_dropout={module_dropout}" "use_tucker={use_tucker}" "use_scalar={use_scalar}" "rank_dropout_scale={rank_dropout_scale}" "algo=loha" "train_norm={train_norm}"' - - if LoRA_type == "LyCORIS/LoKr": - network_module = "lycoris.kohya" - network_args = f' "preset={LyCORIS_preset}" "conv_dim={conv_dim}" 
"conv_alpha={conv_alpha}" "rank_dropout={rank_dropout}" "module_dropout={module_dropout}" "factor={factor}" "use_cp={use_cp}" "use_scalar={use_scalar}" "decompose_both={decompose_both}" "rank_dropout_scale={rank_dropout_scale}" "algo=lokr" "train_norm={train_norm}"' - - if LoRA_type == "LyCORIS/Native Fine-Tuning": - network_module = "lycoris.kohya" - network_args = f' "preset={LyCORIS_preset}" "rank_dropout={rank_dropout}" "module_dropout={module_dropout}" "use_tucker={use_tucker}" "use_scalar={use_scalar}" "rank_dropout_scale={rank_dropout_scale}" "algo=full" "train_norm={train_norm}"' - - if LoRA_type in ["Kohya LoCon", "Standard"]: - kohya_lora_var_list = [ - "down_lr_weight", - "mid_lr_weight", - "up_lr_weight", - "block_lr_zero_threshold", - "block_dims", - "block_alphas", - "conv_block_dims", - "conv_block_alphas", - "rank_dropout", - "module_dropout", - ] - - network_module = "networks.lora" - kohya_lora_vars = { - key: value - for key, value in vars().items() - if key in kohya_lora_var_list and value - } - - network_args = "" - if LoRA_type == "Kohya LoCon": - network_args += f' conv_dim="{conv_dim}" conv_alpha="{conv_alpha}"' - - for key, value in kohya_lora_vars.items(): - if value: - network_args += f' {key}="{value}"' - - if LoRA_type in [ - "LoRA-FA", - ]: - kohya_lora_var_list = [ - "down_lr_weight", - "mid_lr_weight", - "up_lr_weight", - "block_lr_zero_threshold", - "block_dims", - "block_alphas", - "conv_block_dims", - "conv_block_alphas", - "rank_dropout", - "module_dropout", - ] - - network_module = "networks.lora_fa" - kohya_lora_vars = { - key: value - for key, value in vars().items() - if key in kohya_lora_var_list and value - } - - network_args = "" - if LoRA_type == "Kohya LoCon": - network_args += f' conv_dim="{conv_dim}" conv_alpha="{conv_alpha}"' - - for key, value in kohya_lora_vars.items(): - if value: - network_args += f' {key}="{value}"' - - if LoRA_type in ["Kohya DyLoRA"]: - kohya_lora_var_list = [ - "conv_dim", - "conv_alpha", - "down_lr_weight", - "mid_lr_weight", - "up_lr_weight", - "block_lr_zero_threshold", - "block_dims", - "block_alphas", - "conv_block_dims", - "conv_block_alphas", - "rank_dropout", - "module_dropout", - "unit", - ] - - network_module = "networks.dylora" - kohya_lora_vars = { - key: value - for key, value in vars().items() - if key in kohya_lora_var_list and value - } - - network_args = "" - - for key, value in kohya_lora_vars.items(): - if value: - network_args += f' {key}="{value}"' - - network_train_text_encoder_only = False - network_train_unet_only = False - - # Convert learning rates to float once and store the result for re-use - if text_encoder_lr is None: - output_message( - msg="Please input valid Text Encoder learning rate (between 0 and 1)", headless=headless_bool - ) - return - if unet_lr is None: - output_message( - msg="Please input valid Unet learning rate (between 0 and 1)", headless=headless_bool - ) - return - text_encoder_lr_float = float(text_encoder_lr) - unet_lr_float = float(unet_lr) - - - - # Determine the training configuration based on learning rate values - if text_encoder_lr_float == 0 and unet_lr_float == 0: - if float(learning_rate) == 0: - output_message( - msg="Please input learning rate values.", headless=headless_bool - ) - return - elif text_encoder_lr_float != 0 and unet_lr_float == 0: - network_train_text_encoder_only = True - elif text_encoder_lr_float == 0 and unet_lr_float != 0: - network_train_unet_only = True - # If both learning rates are non-zero, no specific flags need to be set - - 
run_cmd += run_cmd_advanced_training( - adaptive_noise_scale=adaptive_noise_scale, - additional_parameters=additional_parameters, - bucket_no_upscale=bucket_no_upscale, - bucket_reso_steps=bucket_reso_steps, - cache_latents=cache_latents, - cache_latents_to_disk=cache_latents_to_disk, - cache_text_encoder_outputs=True if sdxl and sdxl_cache_text_encoder_outputs else None, - caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, - caption_dropout_rate=caption_dropout_rate, - caption_extension=caption_extension, - clip_skip=clip_skip, - color_aug=color_aug, - debiased_estimation_loss=debiased_estimation_loss, - dim_from_weights=dim_from_weights, - enable_bucket=enable_bucket, - epoch=epoch, - flip_aug=flip_aug, - fp8_base=fp8_base, - full_bf16=full_bf16, - full_fp16=full_fp16, - gradient_accumulation_steps=gradient_accumulation_steps, - gradient_checkpointing=gradient_checkpointing, - keep_tokens=keep_tokens, - learning_rate=learning_rate, - logging_dir=logging_dir, - lora_network_weights=lora_network_weights, - lr_scheduler=lr_scheduler, - lr_scheduler_args=lr_scheduler_args, - lr_scheduler_num_cycles=lr_scheduler_num_cycles, - lr_scheduler_power=lr_scheduler_power, - lr_warmup_steps=lr_warmup_steps, - max_bucket_reso=max_bucket_reso, - max_data_loader_n_workers=max_data_loader_n_workers, - max_grad_norm=max_grad_norm, - max_resolution=max_resolution, - max_timestep=max_timestep, - max_token_length=max_token_length, - max_train_epochs=max_train_epochs, - max_train_steps=max_train_steps, - mem_eff_attn=mem_eff_attn, - min_bucket_reso=min_bucket_reso, - min_snr_gamma=min_snr_gamma, - min_timestep=min_timestep, - mixed_precision=mixed_precision, - multires_noise_discount=multires_noise_discount, - multires_noise_iterations=multires_noise_iterations, - network_alpha=network_alpha, - network_args=network_args, - network_dim=network_dim, - network_dropout=network_dropout, - network_module=network_module, - network_train_unet_only=network_train_unet_only, - network_train_text_encoder_only=network_train_text_encoder_only, - no_half_vae=True if sdxl and sdxl_no_half_vae else None, - # no_token_padding=no_token_padding, - noise_offset=noise_offset, - noise_offset_type=noise_offset_type, - optimizer=optimizer, - optimizer_args=optimizer_args, - output_dir=output_dir, - output_name=output_name, - persistent_data_loader_workers=persistent_data_loader_workers, - pretrained_model_name_or_path=pretrained_model_name_or_path, - prior_loss_weight=prior_loss_weight, - random_crop=random_crop, - reg_data_dir=reg_data_dir, - resume=resume, - save_every_n_epochs=save_every_n_epochs, - save_every_n_steps=save_every_n_steps, - save_last_n_steps=save_last_n_steps, - save_last_n_steps_state=save_last_n_steps_state, - save_model_as=save_model_as, - save_precision=save_precision, - save_state=save_state, - scale_v_pred_loss_like_noise_pred=scale_v_pred_loss_like_noise_pred, - scale_weight_norms=scale_weight_norms, - seed=seed, - shuffle_caption=shuffle_caption, - stop_text_encoder_training=stop_text_encoder_training, - text_encoder_lr=text_encoder_lr, - train_batch_size=train_batch_size, - train_data_dir=train_data_dir, - training_comment=training_comment, - unet_lr=unet_lr, - use_wandb=use_wandb, - v2=v2, - v_parameterization=v_parameterization, - v_pred_like_loss=v_pred_like_loss, - vae=vae, - vae_batch_size=vae_batch_size, - wandb_api_key=wandb_api_key, - weighted_captions=weighted_captions, - xformers=xformers, - ) - - run_cmd += run_cmd_sample( - sample_every_n_steps, - sample_every_n_epochs, - 
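- # run_cmd_sample is assumed to turn these values into the matching --sample_every_n_* /
- # --sample_sampler flags and to stage sample_prompts as a prompt file under output_dir;
- # a sketch of the expected contract rather than the verified implementation.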
sample_sampler, - sample_prompts, - output_dir, - ) - - if print_only_bool: - log.warning( - "Here is the trainer command as a reference. It will not be executed:\n" - ) - print(run_cmd) - - save_to_file(run_cmd) - else: - # Saving config file for model - current_datetime = datetime.now() - formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") - file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json") - - log.info(f"Saving training config to {file_path}...") - - SaveConfigFile( - parameters=parameters, - file_path=file_path, - exclusion=["file_path", "save_as", "headless", "print_only"], - ) - - log.info(run_cmd) - # Run the command - executor.execute_command(run_cmd=run_cmd) - - # # check if output_dir/last is a folder... therefore it is a diffuser model - # last_dir = pathlib.Path(f'{output_dir}/{output_name}') - - # if not last_dir.is_dir(): - # # Copy inference model for v2 if required - # save_inference_file( - # output_dir, v2, v_parameterization, output_name - # ) - - -def lora_tab( - train_data_dir_input=gr.Textbox(), - reg_data_dir_input=gr.Textbox(), - output_dir_input=gr.Textbox(), - logging_dir_input=gr.Textbox(), - headless=False, -): - dummy_db_true = gr.Label(value=True, visible=False) - dummy_db_false = gr.Label(value=False, visible=False) - dummy_headless = gr.Label(value=headless, visible=False) - - with gr.Tab("Training"): - gr.Markdown( - "Train a custom model using kohya train network LoRA python code..." - ) - - # Setup Configuration Files Gradio - config = ConfigurationFile(headless) - - source_model = SourceModel( - save_model_as_choices=[ - "ckpt", - "safetensors", - ], - headless=headless, - ) - - with gr.Tab("Folders"): - folders = Folders(headless=headless) - - with gr.Tab("Parameters"): - - def list_presets(path): - json_files = [] - - # Insert an empty string at the beginning - json_files.insert(0, "none") - - for file in os.listdir(path): - if file.endswith(".json"): - json_files.append(os.path.splitext(file)[0]) - - user_presets_path = os.path.join(path, "user_presets") - if os.path.isdir(user_presets_path): - for file in os.listdir(user_presets_path): - if file.endswith(".json"): - preset_name = os.path.splitext(file)[0] - json_files.append(os.path.join("user_presets", preset_name)) - - return json_files - - training_preset = gr.Dropdown( - label="Presets", - choices=list_presets("./presets/lora"), - elem_id="myDropdown", - value="none" - ) - - with gr.Tab("Basic", elem_id="basic_tab"): - with gr.Row(): - LoRA_type = gr.Dropdown( - label="LoRA type", - choices=[ - "Kohya DyLoRA", - "Kohya LoCon", - "LoRA-FA", - "LyCORIS/DyLoRA", - "LyCORIS/iA3", - "LyCORIS/Diag-OFT", - "LyCORIS/GLoRA", - "LyCORIS/LoCon", - "LyCORIS/LoHa", - "LyCORIS/LoKr", - "LyCORIS/Native Fine-Tuning", - "Standard", - ], - value="Standard", - ) - LyCORIS_preset = gr.Dropdown( - label="LyCORIS Preset", - choices=[ - "attn-mlp", - "attn-only", - "full", - "full-lin", - "unet-transformer-only", - "unet-convblock-only", - ], - value="full", - visible=False, - interactive=True - # info="https://github.com/KohakuBlueleaf/LyCORIS/blob/0006e2ffa05a48d8818112d9f70da74c0cd30b99/docs/Preset.md" - ) - with gr.Group(): - with gr.Row(): - lora_network_weights = gr.Textbox( - label="LoRA network weights", - placeholder="(Optional)", - info="Path to an existing LoRA network weights to resume training from", - ) - lora_network_weights_file = gr.Button( - document_symbol, - elem_id="open_folder_small", - visible=(not headless), - ) - lora_network_weights_file.click( - 
get_any_file_path, - inputs=[lora_network_weights], - outputs=lora_network_weights, - show_progress=False, - ) - dim_from_weights = gr.Checkbox( - label="DIM from weights", - value=False, - info="Automatically determine the dim(rank) from the weight file.", - ) - basic_training = BasicTraining( - learning_rate_value="0.0001", - lr_scheduler_value="cosine", - lr_warmup_value="10", - sdxl_checkbox=source_model.sdxl_checkbox, - ) - - with gr.Row(): - text_encoder_lr = gr.Number( - label="Text Encoder learning rate", - value="0.0001", - info="Optional", - minimum=0, - maximum=1, - ) - - unet_lr = gr.Number( - label="Unet learning rate", - value="0.0001", - info="Optional", - minimum=0, - maximum=1, - ) - - # Add SDXL Parameters - sdxl_params = SDXLParameters(source_model.sdxl_checkbox) - - with gr.Row(): - factor = gr.Slider( - label="LoKr factor", - value=-1, - minimum=-1, - maximum=64, - step=1, - visible=False, - ) - use_cp = gr.Checkbox( - value=False, - label="Use CP decomposition", - info="A two-step approach utilizing tensor decomposition and fine-tuning to accelerate convolution layers in large neural networks, resulting in significant CPU speedups with minor accuracy drops.", - visible=False, - ) - use_tucker = gr.Checkbox( - value=False, - label="Use Tucker decomposition", - info="Efficiently decompose tensor shapes, resulting in a sequence of convolution layers with varying dimensions and Hadamard product implementation through multiplication of two distinct tensors.", - visible=False, - ) - use_scalar = gr.Checkbox( - value=False, - label="Use Scalar", - info="Train an additional scalar in front of the weight difference, use a different weight initialization strategy.", - visible=False, - ) - rank_dropout_scale = gr.Checkbox( - value=False, - label="Rank Dropout Scale", - info="Adjusts the scale of the rank dropout to maintain the average dropout rate, ensuring more consistent regularization across different layers.", - visible=False, - ) - constrain = gr.Number( - value="0.0", - label="Constrain OFT", - info="Limits the norm of the oft_blocks, ensuring that their magnitude does not exceed a specified threshold, thus controlling the extent of the transformation applied.", - visible=False, - ) - rescaled = gr.Checkbox( - value=False, - label="Rescaled OFT", - info="applies an additional scaling factor to the oft_blocks, allowing for further adjustment of their impact on the model's transformations.", - visible=False, - ) - train_norm = gr.Checkbox( - value=False, - label="Train Norm", - info="Selects trainable layers in a network, but trains normalization layers identically across methods as they lack matrix decomposition.", - visible=False, - ) - decompose_both = gr.Checkbox( - value=False, - label="LoKr decompose both", - info=" Controls whether both input and output dimensions of the layer's weights are decomposed into smaller matrices for reparameterization.", - visible=False, - ) - train_on_input = gr.Checkbox( - value=True, - label="iA3 train on input", - info="Set if we change the information going into the system (True) or the information coming out of it (False).", - visible=False, - ) - - with gr.Row() as network_row: - network_dim = gr.Slider( - minimum=1, - maximum=512, - label="Network Rank (Dimension)", - value=8, - step=1, - interactive=True, - ) - network_alpha = gr.Slider( - minimum=0.1, - maximum=1024, - label="Network Alpha", - value=1, - step=0.1, - interactive=True, - info="alpha for LoRA weight scaling", - ) - with gr.Row(visible=False) as convolution_row: - # 
locon = gr.Checkbox(label='Train a LoCon instead of a general LoRA (does not support v2 base models) (may not be able to use some utilities now)', value=False) - conv_dim = gr.Slider( - minimum=0, - maximum=512, - value=1, - step=1, - label="Convolution Rank (Dimension)", - ) - conv_alpha = gr.Slider( - minimum=0, - maximum=512, - value=1, - step=1, - label="Convolution Alpha", - ) - with gr.Row(): - scale_weight_norms = gr.Slider( - label="Scale weight norms", - value=0, - minimum=0, - maximum=10, - step=0.01, - info="Max Norm Regularization is a technique to stabilize network training by limiting the norm of network weights. It may be effective in suppressing overfitting of LoRA and improving stability when used with other LoRAs. See PR #545 on kohya_ss/sd_scripts repo for details. Recommended setting: 1. Higher is weaker, lower is stronger.", - interactive=True, - ) - network_dropout = gr.Slider( - label="Network dropout", - value=0, - minimum=0, - maximum=1, - step=0.01, - info="A normal probability dropout at the neuron level. In the case of LoRA, it is applied to the output of the down projection. Recommended range: 0.1 to 0.5", - ) - rank_dropout = gr.Slider( - label="Rank dropout", - value=0, - minimum=0, - maximum=1, - step=0.01, - info="Specify `rank_dropout` to drop out each rank with the specified probability. Recommended range: 0.1 to 0.3", - ) - module_dropout = gr.Slider( - label="Module dropout", - value=0.0, - minimum=0.0, - maximum=1.0, - step=0.01, - info="Specify `module_dropout` to drop out each module with the specified probability. Recommended range: 0.1 to 0.3", - ) - with gr.Row(visible=False) as kohya_dylora: - unit = gr.Slider( - minimum=1, - maximum=64, - label="DyLoRA Unit / Block size", - value=1, - step=1, - interactive=True, - ) - - # Show or hide LoCon conv settings depending on LoRA type selection - def update_LoRA_settings( - LoRA_type, - conv_dim, - network_dim, - ): - log.info("LoRA type changed...") - - lora_settings_config = { - "network_row": { - "gr_type": gr.Row, - "update_params": { - "visible": LoRA_type - in { - "Kohya DyLoRA", - "Kohya LoCon", - "LoRA-FA", - "LyCORIS/Diag-OFT", - "LyCORIS/DyLoRA", - "LyCORIS/GLoRA", - "LyCORIS/LoCon", - "LyCORIS/LoHa", - "LyCORIS/LoKr", - "Standard", - }, - }, - }, - "convolution_row": { - "gr_type": gr.Row, - "update_params": { - "visible": LoRA_type - in { - "LoCon", - "Kohya DyLoRA", - "Kohya LoCon", - "LoRA-FA", - "LyCORIS/Diag-OFT", - "LyCORIS/DyLoRA", - "LyCORIS/LoHa", - "LyCORIS/LoKr", - "LyCORIS/LoCon", - "LyCORIS/GLoRA", - }, - }, - }, - "kohya_advanced_lora": { - "gr_type": gr.Row, - "update_params": { - "visible": LoRA_type - in { - "Standard", - "Kohya DyLoRA", - "Kohya LoCon", - "LoRA-FA", - }, - }, - }, - "kohya_dylora": { - "gr_type": gr.Row, - "update_params": { - "visible": LoRA_type - in { - "Kohya DyLoRA", - "LyCORIS/DyLoRA", - }, - }, - }, - "lora_network_weights": { - "gr_type": gr.Textbox, - "update_params": { - "visible": LoRA_type - in { - "Standard", - "LoCon", - "Kohya DyLoRA", - "Kohya LoCon", - "LoRA-FA", - "LyCORIS/Diag-OFT", - "LyCORIS/DyLoRA", - "LyCORIS/GLoRA", - "LyCORIS/LoHa", - "LyCORIS/LoCon", - "LyCORIS/LoKr", - }, - }, - }, - "lora_network_weights_file": { - "gr_type": gr.Button, - "update_params": { - "visible": LoRA_type - in { - "Standard", - "LoCon", - "Kohya DyLoRA", - "Kohya LoCon", - "LoRA-FA", - "LyCORIS/Diag-OFT", - "LyCORIS/DyLoRA", - "LyCORIS/GLoRA", - "LyCORIS/LoHa", - "LyCORIS/LoCon", - "LyCORIS/LoKr", - }, - }, - }, - "dim_from_weights": { - "gr_type": gr.Checkbox, - "update_params":
{ - "visible": LoRA_type - in { - "Standard", - "LoCon", - "Kohya DyLoRA", - "Kohya LoCon", - "LoRA-FA", - "LyCORIS/Diag-OFT", - "LyCORIS/DyLoRA", - "LyCORIS/GLoRA", - "LyCORIS/LoHa", - "LyCORIS/LoCon", - "LyCORIS/LoKr", - } - }, - }, - "factor": { - "gr_type": gr.Slider, - "update_params": { - "visible": LoRA_type - in { - "LyCORIS/LoKr", - }, - }, - }, - "conv_dim": { - "gr_type": gr.Slider, - "update_params": { - "maximum": 100000 - if LoRA_type - in { - "LyCORIS/LoHa", - "LyCORIS/LoKr", - "LyCORIS/Diag-OFT", - } - else 512, - "value": conv_dim, # if conv_dim > 512 else conv_dim, - }, - }, - "network_dim": { - "gr_type": gr.Slider, - "update_params": { - "maximum": 100000 - if LoRA_type - in { - "LyCORIS/LoHa", - "LyCORIS/LoKr", - "LyCORIS/Diag-OFT", - } - else 512, - "value": network_dim, # if network_dim > 512 else network_dim, - }, - }, - "use_cp": { - "gr_type": gr.Checkbox, - "update_params": { - "visible": LoRA_type - in { - "LyCORIS/LoKr", - }, - }, - }, - "use_tucker": { - "gr_type": gr.Checkbox, - "update_params": { - "visible": LoRA_type - in { - "LyCORIS/Diag-OFT", - "LyCORIS/DyLoRA", - "LyCORIS/LoCon", - "LyCORIS/LoHa", - "LyCORIS/Native Fine-Tuning", - }, - }, - }, - "use_scalar": { - "gr_type": gr.Checkbox, - "update_params": { - "visible": LoRA_type - in { - "LyCORIS/Diag-OFT", - "LyCORIS/LoCon", - "LyCORIS/LoHa", - "LyCORIS/LoKr", - "LyCORIS/Native Fine-Tuning", - }, - }, - }, - "rank_dropout_scale": { - "gr_type": gr.Checkbox, - "update_params": { - "visible": LoRA_type - in { - "LyCORIS/Diag-OFT", - "LyCORIS/GLoRA", - "LyCORIS/LoCon", - "LyCORIS/LoHa", - "LyCORIS/LoKr", - "LyCORIS/Native Fine-Tuning", - }, - }, - }, - "constrain": { - "gr_type": gr.Number, - "update_params": { - "visible": LoRA_type - in { - "LyCORIS/Diag-OFT", - }, - }, - }, - "rescaled": { - "gr_type": gr.Checkbox, - "update_params": { - "visible": LoRA_type - in { - "LyCORIS/Diag-OFT", - }, - }, - }, - "train_norm": { - "gr_type": gr.Checkbox, - "update_params": { - "visible": LoRA_type - in { - "LyCORIS/DyLoRA", - "LyCORIS/Diag-OFT", - "LyCORIS/GLoRA", - "LyCORIS/LoCon", - "LyCORIS/LoHa", - "LyCORIS/LoKr", - "LyCORIS/Native Fine-Tuning", - }, - }, - }, - "decompose_both": { - "gr_type": gr.Checkbox, - "update_params": { - "visible": LoRA_type in {"LyCORIS/LoKr"}, - }, - }, - "train_on_input": { - "gr_type": gr.Checkbox, - "update_params": { - "visible": LoRA_type in {"LyCORIS/iA3"}, - }, - }, - "scale_weight_norms": { - "gr_type": gr.Slider, - "update_params": { - "visible": LoRA_type - in { - "LoCon", - "Kohya DyLoRA", - "Kohya LoCon", - "LoRA-FA", - "LyCORIS/DyLoRA", - "LyCORIS/GLoRA", - "LyCORIS/LoHa", - "LyCORIS/LoCon", - "LyCORIS/LoKr", - "Standard", - }, - }, - }, - "network_dropout": { - "gr_type": gr.Slider, - "update_params": { - "visible": LoRA_type - in { - "LoCon", - "Kohya DyLoRA", - "Kohya LoCon", - "LoRA-FA", - "LyCORIS/Diag-OFT", - "LyCORIS/DyLoRA", - "LyCORIS/GLoRA", - "LyCORIS/LoCon", - "LyCORIS/LoHa", - "LyCORIS/LoKr", - "LyCORIS/Native Fine-Tuning", - "Standard", - }, - }, - }, - "rank_dropout": { - "gr_type": gr.Slider, - "update_params": { - "visible": LoRA_type - in { - "LoCon", - "Kohya DyLoRA", - "LyCORIS/GLoRA", - "LyCORIS/LoCon", - "LyCORIS/LoHa", - "LyCORIS/LoKR", - "Kohya LoCon", - "LoRA-FA", - "LyCORIS/Native Fine-Tuning", - "Standard", - }, - }, - }, - "module_dropout": { - "gr_type": gr.Slider, - "update_params": { - "visible": LoRA_type - in { - "LoCon", - "LyCORIS/Diag-OFT", - "Kohya DyLoRA", - "LyCORIS/GLoRA", - "LyCORIS/LoCon", - "LyCORIS/LoHa", - "LyCORIS/LoKR", 
- "Kohya LoCon", - "LyCORIS/Native Fine-Tuning", - "LoRA-FA", - "Standard", - }, - }, - }, - "LyCORIS_preset": { - "gr_type": gr.Dropdown, - "update_params": { - "visible": LoRA_type - in { - "LyCORIS/DyLoRA", - "LyCORIS/iA3", - "LyCORIS/Diag-OFT", - "LyCORIS/GLoRA", - "LyCORIS/LoCon", - "LyCORIS/LoHa", - "LyCORIS/LoKr", - "LyCORIS/Native Fine-Tuning", - }, - }, - }, - } - - results = [] - for attr, settings in lora_settings_config.items(): - update_params = settings["update_params"] - - results.append(settings["gr_type"](**update_params)) - - return tuple(results) - - with gr.Tab("Advanced", elem_id="advanced_tab"): - # with gr.Accordion('Advanced Configuration', open=False): - with gr.Row(visible=True) as kohya_advanced_lora: - with gr.Tab(label="Weights"): - with gr.Row(visible=True): - down_lr_weight = gr.Textbox( - label="Down LR weights", - placeholder="(Optional) eg: 0,0,0,0,0,0,1,1,1,1,1,1", - info="Specify the learning rate weight of the down blocks of U-Net.", - ) - mid_lr_weight = gr.Textbox( - label="Mid LR weights", - placeholder="(Optional) eg: 0.5", - info="Specify the learning rate weight of the mid block of U-Net.", - ) - up_lr_weight = gr.Textbox( - label="Up LR weights", - placeholder="(Optional) eg: 0,0,0,0,0,0,1,1,1,1,1,1", - info="Specify the learning rate weight of the up blocks of U-Net. The same as down_lr_weight.", - ) - block_lr_zero_threshold = gr.Textbox( - label="Blocks LR zero threshold", - placeholder="(Optional) eg: 0.1", - info="If the weight is not more than this value, the LoRA module is not created. The default is 0.", - ) - with gr.Tab(label="Blocks"): - with gr.Row(visible=True): - block_dims = gr.Textbox( - label="Block dims", - placeholder="(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2", - info="Specify the dim (rank) of each block. Specify 25 numbers.", - ) - block_alphas = gr.Textbox( - label="Block alphas", - placeholder="(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2", - info="Specify the alpha of each block. Specify 25 numbers as with block_dims. If omitted, the value of network_alpha is used.", - ) - with gr.Tab(label="Conv"): - with gr.Row(visible=True): - conv_block_dims = gr.Textbox( - label="Conv dims", - placeholder="(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2", - info="Extend LoRA to Conv2d 3x3 and specify the dim (rank) of each block. Specify 25 numbers.", - ) - conv_block_alphas = gr.Textbox( - label="Conv alphas", - placeholder="(Optional) eg: 2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2", - info="Specify the alpha of each block when expanding LoRA to Conv2d 3x3. Specify 25 numbers. 
If omitted, the value of conv_alpha is used.", - ) - advanced_training = AdvancedTraining( - headless=headless, training_type="lora" - ) - advanced_training.color_aug.change( - color_aug_changed, - inputs=[advanced_training.color_aug], - outputs=[basic_training.cache_latents], - ) - - with gr.Tab("Samples", elem_id="samples_tab"): - sample = SampleImages() - - LoRA_type.change( - update_LoRA_settings, - inputs=[ - LoRA_type, - conv_dim, - network_dim, - ], - outputs=[ - network_row, - convolution_row, - kohya_advanced_lora, - kohya_dylora, - lora_network_weights, - lora_network_weights_file, - dim_from_weights, - factor, - conv_dim, - network_dim, - use_cp, - use_tucker, - use_scalar, - rank_dropout_scale, - constrain, - rescaled, - train_norm, - decompose_both, - train_on_input, - scale_weight_norms, - network_dropout, - rank_dropout, - module_dropout, - LyCORIS_preset, - ], - ) - - with gr.Tab("Dataset Preparation"): - gr.Markdown( - "This section provides Dreambooth tools to help set up your dataset..." - ) - gradio_dreambooth_folder_creation_tab( - train_data_dir_input=folders.train_data_dir, - reg_data_dir_input=folders.reg_data_dir, - output_dir_input=folders.output_dir, - logging_dir_input=folders.logging_dir, - headless=headless, - ) - gradio_dataset_balancing_tab(headless=headless) - - with gr.Row(): - button_run = gr.Button("Start training", variant="primary") - - button_stop_training = gr.Button("Stop training") - - button_print = gr.Button("Print training command") - - # Setup gradio tensorboard buttons - ( - button_start_tensorboard, - button_stop_tensorboard, - ) = gradio_tensorboard() - - button_start_tensorboard.click( - start_tensorboard, - inputs=[dummy_headless, folders.logging_dir], - show_progress=False, - ) - - button_stop_tensorboard.click( - stop_tensorboard, - show_progress=False, - ) - - settings_list = [ - source_model.pretrained_model_name_or_path, - source_model.v2, - source_model.v_parameterization, - source_model.sdxl_checkbox, - folders.logging_dir, - folders.train_data_dir, - folders.reg_data_dir, - folders.output_dir, - basic_training.max_resolution, - basic_training.learning_rate, - basic_training.lr_scheduler, - basic_training.lr_warmup, - basic_training.train_batch_size, - basic_training.epoch, - basic_training.save_every_n_epochs, - basic_training.mixed_precision, - basic_training.save_precision, - basic_training.seed, - basic_training.num_cpu_threads_per_process, - basic_training.cache_latents, - basic_training.cache_latents_to_disk, - basic_training.caption_extension, - basic_training.enable_bucket, - advanced_training.gradient_checkpointing, - advanced_training.fp8_base, - advanced_training.full_fp16, - # advanced_training.no_token_padding, - basic_training.stop_text_encoder_training, - basic_training.min_bucket_reso, - basic_training.max_bucket_reso, - advanced_training.xformers, - source_model.save_model_as, - advanced_training.shuffle_caption, - advanced_training.save_state, - advanced_training.resume, - advanced_training.prior_loss_weight, - text_encoder_lr, - unet_lr, - network_dim, - lora_network_weights, - dim_from_weights, - advanced_training.color_aug, - advanced_training.flip_aug, - advanced_training.clip_skip, - advanced_training.num_processes, - advanced_training.num_machines, - advanced_training.multi_gpu, - advanced_training.gpu_ids, - advanced_training.gradient_accumulation_steps, - advanced_training.mem_eff_attn, - folders.output_name, - source_model.model_list, - advanced_training.max_token_length, - basic_training.max_train_epochs, -
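# [Editor's sketch, not part of the diff] The LoRA_type.change wiring above
# relies on update_LoRA_settings returning one freshly constructed component
# per entry in its `outputs` list, in the same order. A minimal standalone
# version of that pattern (hypothetical names; assumes the same gradio
# behaviour the deleted code depends on, where a handler may return new
# component instances as updates):
import gradio as gr

CONV_TYPES = {"LoCon", "Kohya LoCon", "LyCORIS/LoCon"}

def toggle_settings(lora_type):
    # One return value per output component, in the order listed below.
    return (
        gr.Row(visible=lora_type in CONV_TYPES),  # show conv row only for conv types
        gr.Slider(maximum=100000 if lora_type.startswith("LyCORIS") else 512),
    )

with gr.Blocks() as demo:
    lora_type = gr.Dropdown(["Standard", "LoCon", "LyCORIS/LoCon"], value="Standard")
    with gr.Row(visible=False) as conv_row:
        conv_dim = gr.Slider(0, 512, value=1, label="Convolution Rank (Dimension)")
    lora_type.change(toggle_settings, inputs=[lora_type], outputs=[conv_row, conv_dim])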
basic_training.max_train_steps, - advanced_training.max_data_loader_n_workers, - network_alpha, - folders.training_comment, - advanced_training.keep_tokens, - basic_training.lr_scheduler_num_cycles, - basic_training.lr_scheduler_power, - advanced_training.persistent_data_loader_workers, - advanced_training.bucket_no_upscale, - advanced_training.random_crop, - advanced_training.bucket_reso_steps, - advanced_training.v_pred_like_loss, - advanced_training.caption_dropout_every_n_epochs, - advanced_training.caption_dropout_rate, - basic_training.optimizer, - basic_training.optimizer_args, - basic_training.lr_scheduler_args, - basic_training.max_grad_norm, - advanced_training.noise_offset_type, - advanced_training.noise_offset, - advanced_training.adaptive_noise_scale, - advanced_training.multires_noise_iterations, - advanced_training.multires_noise_discount, - LoRA_type, - factor, - use_cp, - use_tucker, - use_scalar, - rank_dropout_scale, - constrain, - rescaled, - train_norm, - decompose_both, - train_on_input, - conv_dim, - conv_alpha, - sample.sample_every_n_steps, - sample.sample_every_n_epochs, - sample.sample_sampler, - sample.sample_prompts, - advanced_training.additional_parameters, - advanced_training.vae_batch_size, - advanced_training.min_snr_gamma, - down_lr_weight, - mid_lr_weight, - up_lr_weight, - block_lr_zero_threshold, - block_dims, - block_alphas, - conv_block_dims, - conv_block_alphas, - advanced_training.weighted_captions, - unit, - advanced_training.save_every_n_steps, - advanced_training.save_last_n_steps, - advanced_training.save_last_n_steps_state, - advanced_training.use_wandb, - advanced_training.wandb_api_key, - advanced_training.scale_v_pred_loss_like_noise_pred, - scale_weight_norms, - network_dropout, - rank_dropout, - module_dropout, - sdxl_params.sdxl_cache_text_encoder_outputs, - sdxl_params.sdxl_no_half_vae, - advanced_training.full_bf16, - advanced_training.min_timestep, - advanced_training.max_timestep, - advanced_training.vae, - LyCORIS_preset, - advanced_training.debiased_estimation_loss, - ] - - config.button_open_config.click( - open_configuration, - inputs=[dummy_db_true, dummy_db_false, config.config_file_name] - + settings_list - + [training_preset], - outputs=[config.config_file_name] - + settings_list - + [training_preset, convolution_row], - show_progress=False, - ) - - config.button_load_config.click( - open_configuration, - inputs=[dummy_db_false, dummy_db_false, config.config_file_name] - + settings_list - + [training_preset], - outputs=[config.config_file_name] - + settings_list - + [training_preset, convolution_row], - show_progress=False, - ) - - training_preset.input( - open_configuration, - inputs=[dummy_db_false, dummy_db_true, config.config_file_name] - + settings_list - + [training_preset], - outputs=[gr.Textbox()] + settings_list + [training_preset, convolution_row], - show_progress=False, - ) - - config.button_save_config.click( - save_configuration, - inputs=[dummy_db_false, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) - - config.button_save_as_config.click( - save_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) - - button_run.click( - train_model, - inputs=[dummy_headless] + [dummy_db_false] + settings_list, - show_progress=False, - ) - - button_stop_training.click(executor.kill_command) - - button_print.click( - train_model, - inputs=[dummy_headless] + [dummy_db_true] + 
settings_list, - show_progress=False, - ) - - with gr.Tab("Tools"): - lora_tools = LoRATools(folders=folders, headless=headless) - - with gr.Tab("Guides"): - gr.Markdown("This section provides various LoRA guides and information...") - if os.path.exists("./docs/LoRA/top_level.md"): - with open( - os.path.join("./docs/LoRA/top_level.md"), "r", encoding="utf8" - ) as file: - guides_top_level = file.read() + "\n" - gr.Markdown(guides_top_level) - - return ( - folders.train_data_dir, - folders.reg_data_dir, - folders.output_dir, - folders.logging_dir, - ) - def UI(**kwargs): try: diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index 70ac5e21d..0ebc0f209 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -1,967 +1,16 @@ -# v1: initial release -# v2: add open and save folder icons -# v3: Add new Utilities tab for Dreambooth folder preparation -# v3.1: Adding captioning of images to utilities - +import argparse import gradio as gr -import json -import math import os -import subprocess -import pathlib -import argparse -from datetime import datetime -from library.common_gui import ( - get_file_path, - get_saveasfile_path, - color_aug_changed, - save_inference_file, - run_cmd_advanced_training, - update_my_data, - check_if_model_exist, - output_message, - verify_image_folder_pattern, - SaveConfigFile, - save_to_file, -) -from library.class_configuration_file import ConfigurationFile -from library.class_source_model import SourceModel -from library.class_basic_training import BasicTraining -from library.class_advanced_training import AdvancedTraining -from library.class_folders import Folders -from library.class_sdxl_parameters import SDXLParameters -from library.class_command_executor import CommandExecutor -from library.tensorboard_gui import ( - gradio_tensorboard, - start_tensorboard, - stop_tensorboard, -) -from library.dreambooth_folder_creation_gui import ( - gradio_dreambooth_folder_creation_tab, -) -from library.dataset_balancing_gui import gradio_dataset_balancing_tab -from library.utilities import utilities_tab -from library.class_sample_images import SampleImages, run_cmd_sample - -from library.custom_logging import setup_logging -from library.localization_ext import add_javascript - -# Set up logging -log = setup_logging() - -# Setup command executor -executor = CommandExecutor() - - -def save_configuration( - save_as, - file_path, - pretrained_model_name_or_path, - v2, - v_parameterization, - sdxl, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - cache_latents_to_disk, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - no_token_padding, - stop_text_encoder_training, - min_bucket_reso, - max_bucket_reso, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - num_processes, - num_machines, - multi_gpu, - gpu_ids, - vae, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, - token_string, - init_word, - num_vectors_per_token, - max_train_steps, - weights, - template, - keep_tokens, - lr_scheduler_num_cycles, - lr_scheduler_power, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps,
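# [Editor's sketch, not part of the diff] Throughout these tabs, constant
# booleans reach the handlers through hidden gr.Label components
# (dummy_db_true, dummy_db_false, dummy_headless). The handlers assume a
# Label input arrives as a dict, hence the save_as.get("label") == "True"
# comparisons below. A self-contained illustration of that convention
# (hypothetical names; same assumption about Label serialization that the
# deleted code relies on):
import gradio as gr

def report(flag):
    flag_bool = flag.get("label") == "True"  # recover the boolean from the dict
    return f"flag is {flag_bool}"

with gr.Blocks() as demo:
    dummy_true = gr.Label(value=True, visible=False)
    result = gr.Textbox()
    gr.Button("Check").click(report, inputs=[dummy_true], outputs=[result])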
- v_pred_like_loss, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - lr_scheduler_args, - noise_offset_type, - noise_offset, - adaptive_noise_scale, - multires_noise_iterations, - multires_noise_discount, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, - save_every_n_steps, - save_last_n_steps, - save_last_n_steps_state, - use_wandb, - wandb_api_key, - scale_v_pred_loss_like_noise_pred, - min_timestep, - max_timestep, - sdxl_no_half_vae, -): - # Get list of function parameters and values - parameters = list(locals().items()) - - original_file_path = file_path - - save_as_bool = save_as.get("label") == "True" - - if save_as_bool: - log.info("Save as...") - file_path = get_saveasfile_path(file_path) - else: - log.info("Save...") - if file_path is None or file_path == "": - file_path = get_saveasfile_path(file_path) - - # log.info(file_path) - - if file_path is None or file_path == "": - return original_file_path # In case a file_path was provided and the user decided to cancel the open action - - # Extract the destination directory from the file path - destination_directory = os.path.dirname(file_path) - - # Create the destination directory if it doesn't exist - if not os.path.exists(destination_directory): - os.makedirs(destination_directory) - - SaveConfigFile( - parameters=parameters, - file_path=file_path, - exclusion=["file_path", "save_as"], - ) - - return file_path - - -def open_configuration( - ask_for_file, - file_path, - pretrained_model_name_or_path, - v2, - v_parameterization, - sdxl, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - cache_latents_to_disk, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - no_token_padding, - stop_text_encoder_training, - min_bucket_reso, - max_bucket_reso, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - num_processes, - num_machines, - multi_gpu, - gpu_ids, - vae, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, - token_string, - init_word, - num_vectors_per_token, - max_train_steps, - weights, - template, - keep_tokens, - lr_scheduler_num_cycles, - lr_scheduler_power, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - v_pred_like_loss, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - lr_scheduler_args, - noise_offset_type, - noise_offset, - adaptive_noise_scale, - multires_noise_iterations, - multires_noise_discount, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, - save_every_n_steps, - save_last_n_steps, - save_last_n_steps_state, - use_wandb, - wandb_api_key, - scale_v_pred_loss_like_noise_pred, - min_timestep, - max_timestep, - sdxl_no_half_vae, -): - # Get list of function parameters and values - parameters = list(locals().items()) - - ask_for_file = ask_for_file.get("label") == "True" - - original_file_path = file_path - - if ask_for_file: -
file_path = get_file_path(file_path) - - if file_path != "" and file_path is not None: - # load variables from JSON file - with open(file_path, "r") as f: - my_data = json.load(f) - log.info("Loading config...") - # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True - my_data = update_my_data(my_data) - else: - file_path = original_file_path # In case a file_path was provided and the user decided to cancel the open action - my_data = {} - values = [file_path] - for key, value in parameters: - # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found - if key not in ["ask_for_file", "file_path"]: - values.append(my_data.get(key, value)) - return tuple(values) +from kohya_gui.textual_inversion_gui import ti_tab +from kohya_gui.utilities import utilities_tab +from kohya_gui.custom_logging import setup_logging +from kohya_gui.localization_ext import add_javascript -def train_model( - headless, - print_only, - pretrained_model_name_or_path, - v2, - v_parameterization, - sdxl, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - cache_latents_to_disk, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - no_token_padding, - stop_text_encoder_training_pct, - min_bucket_reso, - max_bucket_reso, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - num_processes, - num_machines, - multi_gpu, - gpu_ids, - vae, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, # Keep this:
it is unused here but required by the common settings list. - token_string, - init_word, - num_vectors_per_token, - max_train_steps, - weights, - template, - keep_tokens, - lr_scheduler_num_cycles, - lr_scheduler_power, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - v_pred_like_loss, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - lr_scheduler_args, - noise_offset_type, - noise_offset, - adaptive_noise_scale, - multires_noise_iterations, - multires_noise_discount, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, - save_every_n_steps, - save_last_n_steps, - save_last_n_steps_state, - use_wandb, - wandb_api_key, - scale_v_pred_loss_like_noise_pred, - min_timestep, - max_timestep, - sdxl_no_half_vae, -): - # Get list of function parameters and values - parameters = list(locals().items()) - print_only_bool = print_only.get("label") == "True" - log.info("Start training TI...") - - headless_bool = headless.get("label") == "True" - - if pretrained_model_name_or_path == "": - output_message( - msg="Source model information is missing", headless=headless_bool - ) - return - - if train_data_dir == "": - output_message(msg="Image folder path is missing", headless=headless_bool) - return - - if not os.path.exists(train_data_dir): - output_message(msg="Image folder does not exist", headless=headless_bool) - return - - if not verify_image_folder_pattern(train_data_dir): - return - - if reg_data_dir != "": - if not os.path.exists(reg_data_dir): - output_message( - msg="Regularisation folder does not exist", - headless=headless_bool, - ) - return - - if not verify_image_folder_pattern(reg_data_dir): - return - - if output_dir == "": - output_message(msg="Output folder path is missing", headless=headless_bool) - return - - if token_string == "": - output_message(msg="Token string is missing", headless=headless_bool) - return - - if init_word == "": - output_message(msg="Init word is missing", headless=headless_bool) - return - - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - if check_if_model_exist(output_name, output_dir, save_model_as, headless_bool): - return - - # if float(noise_offset) > 0 and ( - # multires_noise_iterations > 0 or multires_noise_discount > 0 - # ): - # output_message( - # msg="noise offset and multires_noise can't be set at the same time. 
Only use one or the other.", - # title='Error', - # headless=headless_bool, - # ) - # return - - # if optimizer == 'Adafactor' and lr_warmup != '0': - # output_message( - # msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", - # title='Warning', - # headless=headless_bool, - # ) - # lr_warmup = '0' - - # Get a list of all subfolders in train_data_dir - subfolders = [ - f - for f in os.listdir(train_data_dir) - if os.path.isdir(os.path.join(train_data_dir, f)) - ] - - total_steps = 0 - - # Loop through each subfolder and extract the number of repeats - for folder in subfolders: - # Extract the number of repeats from the folder name - repeats = int(folder.split("_")[0]) - - # Count the number of images in the folder - num_images = len( - [ - f - for f, lower_f in ( - (file, file.lower()) - for file in os.listdir(os.path.join(train_data_dir, folder)) - ) - if lower_f.endswith((".jpg", ".jpeg", ".png", ".webp")) - ] - ) - - # Calculate the total number of steps for this folder - steps = repeats * num_images - total_steps += steps - - # Print the result - log.info(f"Folder {folder}: {steps} steps") - - # Print the result - # log.info(f"{total_steps} total steps") - - if reg_data_dir == "": - reg_factor = 1 - else: - log.info( - "Regularisation images are used... Will double the number of steps required..." - ) - reg_factor = 2 - - # calculate max_train_steps - if max_train_steps == "" or max_train_steps == "0": - max_train_steps = int( - math.ceil( - float(total_steps) - / int(train_batch_size) - / int(gradient_accumulation_steps) - * int(epoch) - * int(reg_factor) - ) - ) - else: - max_train_steps = int(max_train_steps) - - log.info(f"max_train_steps = {max_train_steps}") - - # calculate stop encoder training - if stop_text_encoder_training_pct == None: - stop_text_encoder_training = 0 - else: - stop_text_encoder_training = math.ceil( - float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) - ) - log.info(f"stop_text_encoder_training = {stop_text_encoder_training}") - - lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) - log.info(f"lr_warmup_steps = {lr_warmup_steps}") - - run_cmd = "accelerate launch" - - run_cmd += run_cmd_advanced_training( - num_processes=num_processes, - num_machines=num_machines, - multi_gpu=multi_gpu, - gpu_ids=gpu_ids, - num_cpu_threads_per_process=num_cpu_threads_per_process, - ) - - if sdxl: - run_cmd += f' "./sdxl_train_textual_inversion.py"' - else: - run_cmd += f' "./train_textual_inversion.py"' - - run_cmd += run_cmd_advanced_training( - adaptive_noise_scale=adaptive_noise_scale, - additional_parameters=additional_parameters, - bucket_no_upscale=bucket_no_upscale, - bucket_reso_steps=bucket_reso_steps, - cache_latents=cache_latents, - cache_latents_to_disk=cache_latents_to_disk, - caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, - caption_dropout_rate=caption_dropout_rate, - caption_extension=caption_extension, - clip_skip=clip_skip, - color_aug=color_aug, - enable_bucket=enable_bucket, - epoch=epoch, - flip_aug=flip_aug, - full_fp16=full_fp16, - gradient_accumulation_steps=gradient_accumulation_steps, - gradient_checkpointing=gradient_checkpointing, - keep_tokens=keep_tokens, - learning_rate=learning_rate, - logging_dir=logging_dir, - lr_scheduler=lr_scheduler, - lr_scheduler_args=lr_scheduler_args, - lr_scheduler_num_cycles=lr_scheduler_num_cycles, - lr_scheduler_power=lr_scheduler_power, - lr_warmup_steps=lr_warmup_steps, - max_bucket_reso=max_bucket_reso, - 
max_data_loader_n_workers=max_data_loader_n_workers, - max_resolution=max_resolution, - max_timestep=max_timestep, - max_token_length=max_token_length, - max_train_epochs=max_train_epochs, - max_train_steps=max_train_steps, - mem_eff_attn=mem_eff_attn, - min_bucket_reso=min_bucket_reso, - min_snr_gamma=min_snr_gamma, - min_timestep=min_timestep, - mixed_precision=mixed_precision, - multires_noise_discount=multires_noise_discount, - multires_noise_iterations=multires_noise_iterations, - no_half_vae=True if sdxl and sdxl_no_half_vae else None, - no_token_padding=no_token_padding, - noise_offset=noise_offset, - noise_offset_type=noise_offset_type, - optimizer=optimizer, - optimizer_args=optimizer_args, - output_dir=output_dir, - output_name=output_name, - persistent_data_loader_workers=persistent_data_loader_workers, - pretrained_model_name_or_path=pretrained_model_name_or_path, - prior_loss_weight=prior_loss_weight, - random_crop=random_crop, - reg_data_dir=reg_data_dir, - resume=resume, - save_every_n_epochs=save_every_n_epochs, - save_every_n_steps=save_every_n_steps, - save_last_n_steps=save_last_n_steps, - save_last_n_steps_state=save_last_n_steps_state, - save_model_as=save_model_as, - save_precision=save_precision, - save_state=save_state, - scale_v_pred_loss_like_noise_pred=scale_v_pred_loss_like_noise_pred, - seed=seed, - shuffle_caption=shuffle_caption, - stop_text_encoder_training=stop_text_encoder_training, - train_batch_size=train_batch_size, - train_data_dir=train_data_dir, - use_wandb=use_wandb, - v2=v2, - v_parameterization=v_parameterization, - v_pred_like_loss=v_pred_like_loss, - vae=vae, - vae_batch_size=vae_batch_size, - wandb_api_key=wandb_api_key, - xformers=xformers, - ) - run_cmd += f' --token_string="{token_string}"' - run_cmd += f' --init_word="{init_word}"' - run_cmd += f" --num_vectors_per_token={num_vectors_per_token}" - if not weights == "": - run_cmd += f' --weights="{weights}"' - if template == "object template": - run_cmd += f" --use_object_template" - elif template == "style template": - run_cmd += f" --use_style_template" - - run_cmd += run_cmd_sample( - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - output_dir, - ) - - if print_only_bool: - log.warning( - "Here is the trainer command as a reference. It will not be executed:\n" - ) - print(run_cmd) - - save_to_file(run_cmd) - else: - # Saving config file for model - current_datetime = datetime.now() - formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S") - file_path = os.path.join(output_dir, f"{output_name}_{formatted_datetime}.json") - - log.info(f"Saving training config to {file_path}...") - - SaveConfigFile( - parameters=parameters, - file_path=file_path, - exclusion=["file_path", "save_as", "headless", "print_only"], - ) - - log.info(run_cmd) - - # Run the command - - executor.execute_command(run_cmd=run_cmd) - - # check if output_dir/last is a folder... 
therefore it is a diffusers model - last_dir = pathlib.Path(f"{output_dir}/{output_name}") - - if not last_dir.is_dir(): - # Copy inference model for v2 if required - save_inference_file(output_dir, v2, v_parameterization, output_name) - - -def ti_tab( - headless=False, -): - dummy_db_true = gr.Label(value=True, visible=False) - dummy_db_false = gr.Label(value=False, visible=False) - dummy_headless = gr.Label(value=headless, visible=False) - - with gr.Tab("Training"): - gr.Markdown("Train a TI using the kohya textual inversion Python code...") - - # Setup Configuration Files Gradio - config = ConfigurationFile(headless) - - source_model = SourceModel( - save_model_as_choices=[ - "ckpt", - "safetensors", - ], - headless=headless, - ) - - with gr.Tab("Folders"): - folders = Folders(headless=headless) - with gr.Tab("Parameters"): - with gr.Tab("Basic", elem_id="basic_tab"): - with gr.Row(): - weights = gr.Textbox( - label="Resume TI training", - placeholder="(Optional) Path to existing TI embedding file to keep training", - ) - weights_file_input = gr.Button( - "", - elem_id="open_folder_small", - visible=(not headless), - ) - weights_file_input.click( - get_file_path, - outputs=weights, - show_progress=False, - ) - with gr.Row(): - token_string = gr.Textbox( - label="Token string", - placeholder="eg: cat", - ) - init_word = gr.Textbox( - label="Init word", - value="*", - ) - num_vectors_per_token = gr.Slider( - minimum=1, - maximum=75, - value=1, - step=1, - label="Vectors", - ) - # max_train_steps = gr.Textbox( - # label='Max train steps', - # placeholder='(Optional) Maximum number of steps', - # ) - template = gr.Dropdown( - label="Template", - choices=[ - "caption", - "object template", - "style template", - ], - value="caption", - ) - basic_training = BasicTraining( - learning_rate_value="1e-5", - lr_scheduler_value="cosine", - lr_warmup_value="10", - sdxl_checkbox=source_model.sdxl_checkbox, - ) - - # Add SDXL Parameters - sdxl_params = SDXLParameters( - source_model.sdxl_checkbox, - show_sdxl_cache_text_encoder_outputs=False, - ) - - with gr.Tab("Advanced", elem_id="advanced_tab"): - advanced_training = AdvancedTraining(headless=headless) - advanced_training.color_aug.change( - color_aug_changed, - inputs=[advanced_training.color_aug], - outputs=[basic_training.cache_latents], - ) - - with gr.Tab("Samples", elem_id="samples_tab"): - sample = SampleImages() - - with gr.Tab("Dataset Preparation"): - gr.Markdown( - "This section provides Dreambooth tools to help set up your dataset..." 
- ) - gradio_dreambooth_folder_creation_tab( - train_data_dir_input=folders.train_data_dir, - reg_data_dir_input=folders.reg_data_dir, - output_dir_input=folders.output_dir, - logging_dir_input=folders.logging_dir, - headless=headless, - ) - gradio_dataset_balancing_tab(headless=headless) - - with gr.Row(): - button_run = gr.Button("Start training", variant="primary") - - button_stop_training = gr.Button("Stop training") - - button_print = gr.Button("Print training command") - - # Setup gradio tensorboard buttons - ( - button_start_tensorboard, - button_stop_tensorboard, - ) = gradio_tensorboard() - - button_start_tensorboard.click( - start_tensorboard, - inputs=[dummy_headless, folders.logging_dir], - show_progress=False, - ) - - button_stop_tensorboard.click( - stop_tensorboard, - show_progress=False, - ) - - settings_list = [ - source_model.pretrained_model_name_or_path, - source_model.v2, - source_model.v_parameterization, - source_model.sdxl_checkbox, - folders.logging_dir, - folders.train_data_dir, - folders.reg_data_dir, - folders.output_dir, - basic_training.max_resolution, - basic_training.learning_rate, - basic_training.lr_scheduler, - basic_training.lr_warmup, - basic_training.train_batch_size, - basic_training.epoch, - basic_training.save_every_n_epochs, - basic_training.mixed_precision, - basic_training.save_precision, - basic_training.seed, - basic_training.num_cpu_threads_per_process, - basic_training.cache_latents, - basic_training.cache_latents_to_disk, - basic_training.caption_extension, - basic_training.enable_bucket, - advanced_training.gradient_checkpointing, - advanced_training.full_fp16, - advanced_training.no_token_padding, - basic_training.stop_text_encoder_training, - basic_training.min_bucket_reso, - basic_training.max_bucket_reso, - advanced_training.xformers, - source_model.save_model_as, - advanced_training.shuffle_caption, - advanced_training.save_state, - advanced_training.resume, - advanced_training.prior_loss_weight, - advanced_training.color_aug, - advanced_training.flip_aug, - advanced_training.clip_skip, - advanced_training.num_processes, - advanced_training.num_machines, - advanced_training.multi_gpu, - advanced_training.gpu_ids, - advanced_training.vae, - folders.output_name, - advanced_training.max_token_length, - basic_training.max_train_epochs, - advanced_training.max_data_loader_n_workers, - advanced_training.mem_eff_attn, - advanced_training.gradient_accumulation_steps, - source_model.model_list, - token_string, - init_word, - num_vectors_per_token, - basic_training.max_train_steps, - weights, - template, - advanced_training.keep_tokens, - basic_training.lr_scheduler_num_cycles, - basic_training.lr_scheduler_power, - advanced_training.persistent_data_loader_workers, - advanced_training.bucket_no_upscale, - advanced_training.random_crop, - advanced_training.bucket_reso_steps, - advanced_training.v_pred_like_loss, - advanced_training.caption_dropout_every_n_epochs, - advanced_training.caption_dropout_rate, - basic_training.optimizer, - basic_training.optimizer_args, - basic_training.lr_scheduler_args, - advanced_training.noise_offset_type, - advanced_training.noise_offset, - advanced_training.adaptive_noise_scale, - advanced_training.multires_noise_iterations, - advanced_training.multires_noise_discount, - sample.sample_every_n_steps, - sample.sample_every_n_epochs, - sample.sample_sampler, - sample.sample_prompts, - advanced_training.additional_parameters, - advanced_training.vae_batch_size, - advanced_training.min_snr_gamma, - 
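# [Editor's note, not part of the diff] Each tab's settings_list must stay
# positionally aligned with the parameters of save_configuration,
# open_configuration, and train_model after their leading flag/file-path
# inputs: gradio passes inputs positionally, and the handlers rebuild the
# config from `parameters = list(locals().items())`. A sketch of a cheap
# guard for that invariant (hypothetical helper; the code above relies on
# manual alignment):
import inspect

def check_alignment(handler, settings_list, n_leading):
    params = list(inspect.signature(handler).parameters)[n_leading:]
    assert len(params) == len(settings_list), (
        f"{handler.__name__}: {len(params)} parameters vs "
        f"{len(settings_list)} components in settings_list"
    )

# e.g. check_alignment(train_model, settings_list, n_leading=2)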
advanced_training.save_every_n_steps, - advanced_training.save_last_n_steps, - advanced_training.save_last_n_steps_state, - advanced_training.use_wandb, - advanced_training.wandb_api_key, - advanced_training.scale_v_pred_loss_like_noise_pred, - advanced_training.min_timestep, - advanced_training.max_timestep, - sdxl_params.sdxl_no_half_vae, - ] - - config.button_open_config.click( - open_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name] + settings_list, - show_progress=False, - ) - - config.button_load_config.click( - open_configuration, - inputs=[dummy_db_false, config.config_file_name] + settings_list, - outputs=[config.config_file_name] + settings_list, - show_progress=False, - ) - - config.button_save_config.click( - save_configuration, - inputs=[dummy_db_false, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) - - config.button_save_as_config.click( - save_configuration, - inputs=[dummy_db_true, config.config_file_name] + settings_list, - outputs=[config.config_file_name], - show_progress=False, - ) - - button_run.click( - train_model, - inputs=[dummy_headless] + [dummy_db_false] + settings_list, - show_progress=False, - ) - - button_stop_training.click(executor.kill_command) - - button_print.click( - train_model, - inputs=[dummy_headless] + [dummy_db_true] + settings_list, - show_progress=False, - ) - - return ( - folders.train_data_dir, - folders.reg_data_dir, - folders.output_dir, - folders.logging_dir, - ) +# Set up logging +log = setup_logging() def UI(**kwargs): diff --git a/library/utilities.py b/utilities_gui.py similarity index 53% rename from library/utilities.py rename to utilities_gui.py index 95d30b0ce..65cc066e5 100644 --- a/library/utilities.py +++ b/utilities_gui.py @@ -1,44 +1,15 @@ -# v1: initial release -# v2: add open and save folder icons -# v3: Add new Utilities tab for Dreambooth folder preparation -# v3.1: Adding captioning of images to utilities - +import argparse import gradio as gr import os -import argparse -from library.basic_caption_gui import gradio_basic_caption_gui_tab -from library.convert_model_gui import gradio_convert_model_tab -from library.blip_caption_gui import gradio_blip_caption_gui_tab -from library.git_caption_gui import gradio_git_caption_gui_tab -from library.wd14_caption_gui import gradio_wd14_caption_gui_tab -from library.manual_caption_gui import gradio_manual_caption_gui_tab -from library.group_images_gui import gradio_group_images_gui_tab +from kohya_gui.utilities import utilities_tab -def utilities_tab( - train_data_dir_input=gr.Textbox(), - reg_data_dir_input=gr.Textbox(), - output_dir_input=gr.Textbox(), - logging_dir_input=gr.Textbox(), - enable_copy_info_button=False, - enable_dreambooth_tab=True, - headless=False -): - with gr.Tab('Captioning'): - gradio_basic_caption_gui_tab(headless=headless) - gradio_blip_caption_gui_tab(headless=headless) - gradio_git_caption_gui_tab(headless=headless) - gradio_wd14_caption_gui_tab(headless=headless) - gradio_manual_caption_gui_tab(headless=headless) - gradio_convert_model_tab(headless=headless) - gradio_group_images_gui_tab(headless=headless) +from kohya_gui.custom_logging import setup_logging +from kohya_gui.localization_ext import add_javascript - return ( - train_data_dir_input, - reg_data_dir_input, - output_dir_input, - logging_dir_input, - ) + +# Set up logging +log = setup_logging() def UI(**kwargs):
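# [Editor's note, not part of the diff] train_model above assembles the
# trainer invocation as one shell string: "accelerate launch", the quoted
# script path, then flags appended piecewise (--token_string="...", etc.).
# A sketch of the same construction with explicit shlex quoting, which
# avoids breakage on values containing spaces or quotes (an illustrative
# alternative only, not the repo's run_cmd_advanced_training helper):
import shlex

def build_run_cmd(script, **flags):
    parts = ["accelerate", "launch", script]
    for name, value in flags.items():
        if value is None or value == "":
            continue  # skip unset flags
        if value is True:
            parts.append(f"--{name}")  # boolean switch
        else:
            parts.append(f"--{name}={shlex.quote(str(value))}")
    return " ".join(parts)

# build_run_cmd("./train_textual_inversion.py", token_string="cat", init_word="*", num_vectors_per_token=1)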