From d96df32983a37aabf7e85911ea28bb3c2dd27247 Mon Sep 17 00:00:00 2001 From: b-fission Date: Mon, 5 Aug 2024 23:14:40 -0500 Subject: [PATCH] bring back SDXLConfig accordion for dreambooth gui --- kohya_gui/dreambooth_gui.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/kohya_gui/dreambooth_gui.py b/kohya_gui/dreambooth_gui.py index a38230a21..4e8e5eece 100644 --- a/kohya_gui/dreambooth_gui.py +++ b/kohya_gui/dreambooth_gui.py @@ -31,6 +31,7 @@ from .class_command_executor import CommandExecutor from .class_huggingface import HuggingFace from .class_metadata import MetaData +from .class_sdxl_parameters import SDXLParameters from .dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, @@ -162,6 +163,8 @@ def save_configuration( log_tracker_name, log_tracker_config, scale_v_pred_loss_like_noise_pred, + sdxl_cache_text_encoder_outputs, + sdxl_no_half_vae, min_timestep, max_timestep, debiased_estimation_loss, @@ -320,6 +323,8 @@ def open_configuration( log_tracker_name, log_tracker_config, scale_v_pred_loss_like_noise_pred, + sdxl_cache_text_encoder_outputs, + sdxl_no_half_vae, min_timestep, max_timestep, debiased_estimation_loss, @@ -473,6 +478,8 @@ def train_model( log_tracker_name, log_tracker_config, scale_v_pred_loss_like_noise_pred, + sdxl_cache_text_encoder_outputs, + sdxl_no_half_vae, min_timestep, max_timestep, debiased_estimation_loss, @@ -705,6 +712,9 @@ def train_model( else: run_cmd.append(rf"{scriptdir}/sd-scripts/train_db.py") + cache_text_encoder_outputs = sdxl and sdxl_cache_text_encoder_outputs + no_half_vae = sdxl and sdxl_no_half_vae + if max_data_loader_n_workers == "" or None: max_data_loader_n_workers = 0 else: @@ -724,6 +734,7 @@ def train_model( "bucket_reso_steps": bucket_reso_steps, "cache_latents": cache_latents, "cache_latents_to_disk": cache_latents_to_disk, + "cache_text_encoder_outputs": cache_text_encoder_outputs, "caption_dropout_every_n_epochs": int(caption_dropout_every_n_epochs), "caption_dropout_rate": caption_dropout_rate, "caption_extension": caption_extension, @@ -789,6 +800,7 @@ def train_model( "mixed_precision": mixed_precision, "multires_noise_discount": multires_noise_discount, "multires_noise_iterations": multires_noise_iterations if not 0 else None, + "no_half_vae": no_half_vae, "no_token_padding": no_token_padding, "noise_offset": noise_offset if not 0 else None, "noise_offset_random_strength": noise_offset_random_strength, @@ -981,6 +993,11 @@ def dreambooth_tab( config=config, ) + # Add SDXL Parameters + sdxl_params = SDXLParameters( + source_model.sdxl_checkbox, config=config + ) + with gr.Accordion("Advanced", open=False, elem_id="advanced_tab"): advanced_training = AdvancedTraining(headless=headless, config=config) advanced_training.color_aug.change( @@ -1112,6 +1129,8 @@ def dreambooth_tab( advanced_training.log_tracker_name, advanced_training.log_tracker_config, advanced_training.scale_v_pred_loss_like_noise_pred, + sdxl_params.sdxl_cache_text_encoder_outputs, + sdxl_params.sdxl_no_half_vae, advanced_training.min_timestep, advanced_training.max_timestep, advanced_training.debiased_estimation_loss,