Skip to content

Commit

Permalink
Fix bugs, improve layout logic
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Apr 11, 2024
1 parent 48a3242 commit 8633484
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 137 deletions.
13 changes: 1 addition & 12 deletions kohya_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,4 @@ def UI(**kwargs):

args = parser.parse_args()

UI(
config_file_path=args.config,
username=args.username,
password=args.password,
inbrowser=args.inbrowser,
server_port=args.server_port,
share=args.share,
do_not_share=args.do_not_share,
listen=args.listen,
headless=args.headless,
language=args.language,
)
UI(**vars(args))
42 changes: 20 additions & 22 deletions kohya_gui/dreambooth_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,19 +435,16 @@ def train_model(
# Get list of function parameters and values
parameters = list(locals().items())

print_only_bool = True if print_only.get("label") == "True" else False
log.info(f"Start training Dreambooth...")

headless_bool = True if headless.get("label") == "True" else False

# This function validates files or folder paths. Simply add new variables containing a file or folder path
# to validate below
if not validate_paths(
output_dir=output_dir,
pretrained_model_name_or_path=pretrained_model_name_or_path,
train_data_dir=train_data_dir,
reg_data_dir=reg_data_dir,
headless=headless_bool,
headless=headless,
logging_dir=logging_dir,
log_tracker_config=log_tracker_config,
resume=resume,
Expand All @@ -456,8 +453,8 @@ def train_model(
):
return

if not print_only_bool and check_if_model_exist(
output_name, output_dir, save_model_as, headless=headless_bool
if not print_only and check_if_model_exist(
output_name, output_dir, save_model_as, headless=headless
):
return

Expand Down Expand Up @@ -693,7 +690,7 @@ def train_model(
output_dir,
)

if print_only_bool:
if print_only:
log.warning(
"Here is the trainer command as a reference. It will not be executed:\n"
)
Expand Down Expand Up @@ -738,7 +735,7 @@ def dreambooth_tab(
):
dummy_db_true = gr.Checkbox(value=True, visible=False)
dummy_db_false = gr.Checkbox(value=False, visible=False)
dummy_headless = gr.Label(value=headless, visible=False)
dummy_headless = gr.Checkbox(value=headless, visible=False)

with gr.Tab("Training"), gr.Column(variant="compact"):
gr.Markdown("Train a custom model using kohya dreambooth python code...")
Expand All @@ -756,6 +753,21 @@ def dreambooth_tab(
with gr.Accordion("Folders", open=False), gr.Group():
folders = Folders(headless=headless, config=config)

with gr.Accordion("Dataset Preparation", open=False):
gr.Markdown(
"This section provides Dreambooth tools to help set up your dataset..."
)
gradio_dreambooth_folder_creation_tab(
train_data_dir_input=source_model.train_data_dir,
reg_data_dir_input=folders.reg_data_dir,
output_dir_input=folders.output_dir,
logging_dir_input=folders.logging_dir,
headless=headless,
config=config,
)

gradio_dataset_balancing_tab(headless=headless)

with gr.Accordion("Parameters", open=False), gr.Column():
with gr.Accordion("Basic", open="True"):
with gr.Group(elem_id="basic_tab"):
Expand All @@ -779,20 +791,6 @@ def dreambooth_tab(
with gr.Accordion("Samples", open=False, elem_id="samples_tab"):
sample = SampleImages(config=config)

with gr.Accordion("Dataset Preparation", open=False):
gr.Markdown(
"This section provides Dreambooth tools to help set up your dataset..."
)
gradio_dreambooth_folder_creation_tab(
train_data_dir_input=source_model.train_data_dir,
reg_data_dir_input=folders.reg_data_dir,
output_dir_input=folders.output_dir,
logging_dir_input=folders.logging_dir,
headless=headless,
config=config,
)
gradio_dataset_balancing_tab(headless=headless)

with gr.Column(), gr.Group():
with gr.Row():
button_run = gr.Button("Start training", variant="primary")
Expand Down
101 changes: 50 additions & 51 deletions kohya_gui/finetune_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,29 +469,28 @@ def train_model(
):
# Get list of function parameters and values
parameters = list(locals().items())

print_only_bool = True if print_only.get("label") == "True" else False

log.debug(f"headless = {headless} ; print_only = {print_only}")

log.info(f"Start Finetuning...")

headless_bool = True if headless.get("label") == "True" else False

if train_dir != "" and not os.path.exists(train_dir):
os.mkdir(train_dir)

if not validate_paths(
output_dir=output_dir,
pretrained_model_name_or_path=pretrained_model_name_or_path,
finetune_image_folder=image_folder,
headless=headless_bool,
headless=headless,
logging_dir=logging_dir,
log_tracker_config=log_tracker_config,
resume=resume,
dataset_config=dataset_config,
):
return

if not print_only_bool and check_if_model_exist(
output_name, output_dir, save_model_as, headless_bool
if not print_only and check_if_model_exist(
output_name, output_dir, save_model_as, headless
):
return

Expand Down Expand Up @@ -519,7 +518,7 @@ def train_model(
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)

if not print_only_bool:
if not print_only:
# Run the command
subprocess.run(run_cmd, shell=True, env=env)

Expand Down Expand Up @@ -552,7 +551,7 @@ def train_model(
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)

if not print_only_bool:
if not print_only:
# Run the command
subprocess.run(run_cmd, shell=True, env=env)

Expand Down Expand Up @@ -724,7 +723,7 @@ def train_model(
output_dir,
)

if print_only_bool:
if print_only:
log.warning(
"Here is the trainer command as a reference. It will not be executed:\n"
)
Expand Down Expand Up @@ -761,7 +760,7 @@ def train_model(
def finetune_tab(headless=False, config: dict = {}):
dummy_db_true = gr.Checkbox(value=True, visible=False)
dummy_db_false = gr.Checkbox(value=False, visible=False)
dummy_headless = gr.Label(value=headless, visible=False)
dummy_headless = gr.Checkbox(value=headless, visible=False)
with gr.Tab("Training"), gr.Column(variant="compact"):
gr.Markdown("Train a custom model using kohya finetune python code...")

Expand All @@ -785,6 +784,46 @@ def finetune_tab(headless=False, config: dict = {}):
logging_dir = folders.logging_dir
train_dir = folders.reg_data_dir

with gr.Accordion("Dataset Preparation", open=False):
with gr.Row():
max_resolution = gr.Textbox(
label="Resolution (width,height)", value="512,512"
)
min_bucket_reso = gr.Textbox(label="Min bucket resolution", value="256")
max_bucket_reso = gr.Textbox(
label="Max bucket resolution", value="1024"
)
batch_size = gr.Textbox(label="Batch size", value="1")
with gr.Row():
create_caption = gr.Checkbox(
label="Generate caption metadata", value=True
)
create_buckets = gr.Checkbox(
label="Generate image buckets metadata", value=True
)
use_latent_files = gr.Dropdown(
label="Use latent files",
choices=[
"No",
"Yes",
],
value="Yes",
)
with gr.Accordion("Advanced parameters", open=False):
with gr.Row():
caption_metadata_filename = gr.Textbox(
label="Caption metadata filename",
value="meta_cap.json",
)
latent_metadata_filename = gr.Textbox(
label="Latent metadata filename", value="meta_lat.json"
)
with gr.Row():
full_path = gr.Checkbox(label="Use full path", value=True)
weighted_captions = gr.Checkbox(
label="Weighted captions", value=False
)

with gr.Accordion("Parameters", open=False), gr.Column():

def list_presets(path):
Expand Down Expand Up @@ -853,46 +892,6 @@ def list_presets(path):
with gr.Accordion("Samples", open=False, elem_id="samples_tab"):
sample = SampleImages(config=config)

with gr.Accordion("Dataset Preparation", open=False):
with gr.Row():
max_resolution = gr.Textbox(
label="Resolution (width,height)", value="512,512"
)
min_bucket_reso = gr.Textbox(label="Min bucket resolution", value="256")
max_bucket_reso = gr.Textbox(
label="Max bucket resolution", value="1024"
)
batch_size = gr.Textbox(label="Batch size", value="1")
with gr.Row():
create_caption = gr.Checkbox(
label="Generate caption metadata", value=True
)
create_buckets = gr.Checkbox(
label="Generate image buckets metadata", value=True
)
use_latent_files = gr.Dropdown(
label="Use latent files",
choices=[
"No",
"Yes",
],
value="Yes",
)
with gr.Accordion("Advanced parameters", open=False):
with gr.Row():
caption_metadata_filename = gr.Textbox(
label="Caption metadata filename",
value="meta_cap.json",
)
latent_metadata_filename = gr.Textbox(
label="Latent metadata filename", value="meta_lat.json"
)
with gr.Row():
full_path = gr.Checkbox(label="Use full path", value=True)
weighted_captions = gr.Checkbox(
label="Weighted captions", value=False
)

with gr.Column(), gr.Group():
with gr.Row():
button_run = gr.Button("Start training", variant="primary")
Expand Down
49 changes: 24 additions & 25 deletions kohya_gui/lora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,16 +633,14 @@ def train_model(
parameters = list(locals().items())
global command_running

print_only_bool = True if print_only.get("label") == "True" else False
log.info(f"Start training LoRA {LoRA_type} ...")
headless_bool = True if headless.get("label") == "True" else False

if not validate_paths(
output_dir=output_dir,
pretrained_model_name_or_path=pretrained_model_name_or_path,
train_data_dir=train_data_dir,
reg_data_dir=reg_data_dir,
headless=headless_bool,
headless=headless,
logging_dir=logging_dir,
log_tracker_config=log_tracker_config,
resume=resume,
Expand All @@ -655,7 +653,7 @@ def train_model(
if int(bucket_reso_steps) < 1:
output_message(
msg="Bucket resolution steps need to be greater than 0",
headless=headless_bool,
headless=headless,
)
return

Expand All @@ -665,7 +663,7 @@ def train_model(
if float(noise_offset) > 1 or float(noise_offset) < 0:
output_message(
msg="Noise offset need to be a value between 0 and 1",
headless=headless_bool,
headless=headless,
)
return

Expand All @@ -676,12 +674,12 @@ def train_model(
if stop_text_encoder_training_pct > 0:
output_message(
msg='Output "stop text encoder training" is not yet supported. Ignoring',
headless=headless_bool,
headless=headless,
)
stop_text_encoder_training_pct = 0

if not print_only_bool and check_if_model_exist(
output_name, output_dir, save_model_as, headless=headless_bool
if not print_only and check_if_model_exist(
output_name, output_dir, save_model_as, headless=headless
):
return

Expand Down Expand Up @@ -913,7 +911,7 @@ def train_model(
# Determine the training configuration based on learning rate values
# Sets flags for training specific components based on the provided learning rates.
if float(learning_rate) == unet_lr_float == text_encoder_lr_float == 0:
output_message(msg="Please input learning rate values.", headless=headless_bool)
output_message(msg="Please input learning rate values.", headless=headless)
return
# Flag to train text encoder only if its learning rate is non-zero and unet's is zero.
network_train_text_encoder_only = text_encoder_lr_float != 0 and unet_lr_float == 0
Expand Down Expand Up @@ -1041,7 +1039,7 @@ def train_model(
output_dir,
)

if print_only_bool:
if print_only:
log.warning(
"Here is the trainer command as a reference. It will not be executed:\n"
)
Expand Down Expand Up @@ -1084,7 +1082,7 @@ def lora_tab(
):
dummy_db_true = gr.Checkbox(value=True, visible=False)
dummy_db_false = gr.Checkbox(value=False, visible=False)
dummy_headless = gr.Label(value=headless, visible=False)
dummy_headless = gr.Checkbox(value=headless, visible=False)

with gr.Tab("Training"), gr.Column(variant="compact") as tab:
gr.Markdown(
Expand All @@ -1111,6 +1109,21 @@ def lora_tab(
with gr.Accordion("Folders", open=False), gr.Group():
folders = Folders(headless=headless, config=config)

with gr.Accordion("Dataset Preparation", open=False):
gr.Markdown(
"This section provides Dreambooth tools to help set up your dataset..."
)
gradio_dreambooth_folder_creation_tab(
train_data_dir_input=source_model.train_data_dir,
reg_data_dir_input=folders.reg_data_dir,
output_dir_input=folders.output_dir,
logging_dir_input=folders.logging_dir,
headless=headless,
config=config,
)

gradio_dataset_balancing_tab(headless=headless)

with gr.Accordion("Parameters", open=False), gr.Column():

def list_presets(path):
Expand Down Expand Up @@ -1900,20 +1913,6 @@ def update_LoRA_settings(
],
)

with gr.Accordion("Dataset Preparation", open=False):
gr.Markdown(
"This section provides Dreambooth tools to help set up your dataset..."
)
gradio_dreambooth_folder_creation_tab(
train_data_dir_input=source_model.train_data_dir,
reg_data_dir_input=folders.reg_data_dir,
output_dir_input=folders.output_dir,
logging_dir_input=folders.logging_dir,
headless=headless,
config=config,
)
gradio_dataset_balancing_tab(headless=headless)

with gr.Column(), gr.Group():
with gr.Row():
button_run = gr.Button("Start training", variant="primary")
Expand Down
Loading

0 comments on commit 8633484

Please sign in to comment.