Skip to content

Commit

Permalink
Add support for:
Browse files Browse the repository at this point in the history
shuffle_caption,
save_state,
resume,
prior_loss_weight,

Fix issue with config open and save
  • Loading branch information
bmaltais committed Dec 20, 2022
1 parent 1d41272 commit 1f1dd5c
Show file tree
Hide file tree
Showing 7 changed files with 1,636 additions and 1,115 deletions.
167 changes: 117 additions & 50 deletions dreambooth_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import subprocess
import pathlib
import shutil
from library.dreambooth_folder_creation_gui import gradio_dreambooth_folder_creation_tab
from library.dreambooth_folder_creation_gui import (
gradio_dreambooth_folder_creation_tab,
)
from library.basic_caption_gui import gradio_basic_caption_gui_tab
from library.convert_model_gui import gradio_convert_model_tab
from library.blip_caption_gui import gradio_blip_caption_gui_tab
Expand All @@ -20,14 +22,14 @@
get_folder_path,
remove_doublequote,
get_file_path,
get_saveasfile_path
get_saveasfile_path,
)
from easygui import msgbox

folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
document_symbol = '\U0001F4C4' # 📄


def save_configuration(
Expand Down Expand Up @@ -60,30 +62,26 @@ def save_configuration(
stop_text_encoder_training,
use_8bit_adam,
xformers,
save_model_as
save_model_as,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
):
original_file_path = file_path

save_as_bool = True if save_as.get('label') == 'True' else False

if save_as_bool:
print('Save as...')
# file_path = filesavebox(
# 'Select the config file to save',
# default='finetune.json',
# filetypes='*.json',
# )
file_path = get_saveasfile_path(file_path)
else:
print('Save...')
if file_path == None or file_path == '':
# file_path = filesavebox(
# 'Select the config file to save',
# default='finetune.json',
# filetypes='*.json',
# )
file_path = get_saveasfile_path(file_path)

# print(file_path)

if file_path == None or file_path == '':
return original_file_path # In case a file_path was provided and the user decide to cancel the open action

Expand Down Expand Up @@ -116,7 +114,11 @@ def save_configuration(
'stop_text_encoder_training': stop_text_encoder_training,
'use_8bit_adam': use_8bit_adam,
'xformers': xformers,
'save_model_as': save_model_as
'save_model_as': save_model_as,
'shuffle_caption': shuffle_caption,
'save_state': save_state,
'resume': resume,
'prior_loss_weight': prior_loss_weight,
}

# Save the data to the selected file
Expand Down Expand Up @@ -155,14 +157,18 @@ def open_configuration(
stop_text_encoder_training,
use_8bit_adam,
xformers,
save_model_as
save_model_as,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
):

original_file_path = file_path
file_path = get_file_path(file_path)
# print(file_path)

if file_path != '' and file_path != None:
print(file_path)
if not file_path == '' and not file_path == None:
# load variables from JSON file
with open(file_path, 'r') as f:
my_data = json.load(f)
Expand Down Expand Up @@ -204,7 +210,11 @@ def open_configuration(
my_data.get('stop_text_encoder_training', stop_text_encoder_training),
my_data.get('use_8bit_adam', use_8bit_adam),
my_data.get('xformers', xformers),
my_data.get('save_model_as', save_model_as)
my_data.get('save_model_as', save_model_as),
my_data.get('shuffle_caption', shuffle_caption),
my_data.get('save_state', save_state),
my_data.get('resume', resume),
my_data.get('prior_loss_weight', prior_loss_weight),
)


Expand Down Expand Up @@ -236,7 +246,11 @@ def train_model(
stop_text_encoder_training_pct,
use_8bit_adam,
xformers,
save_model_as
save_model_as,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
):
def save_inference_file(output_dir, v2, v_parameterization):
# Copy inference model for v2 if required
Expand Down Expand Up @@ -360,6 +374,10 @@ def save_inference_file(output_dir, v2, v_parameterization):
run_cmd += ' --use_8bit_adam'
if xformers:
run_cmd += ' --xformers'
if shuffle_caption:
run_cmd += ' --shuffle_caption'
if save_state:
run_cmd += ' --save_state'
run_cmd += (
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
)
Expand All @@ -382,17 +400,23 @@ def save_inference_file(output_dir, v2, v_parameterization):
run_cmd += f' --logging_dir={logging_dir}'
run_cmd += f' --caption_extention={caption_extention}'
if not stop_text_encoder_training == 0:
run_cmd += f' --stop_text_encoder_training={stop_text_encoder_training}'
run_cmd += (
f' --stop_text_encoder_training={stop_text_encoder_training}'
)
if not save_model_as == 'same as source model':
run_cmd += f' --save_model_as={save_model_as}'
if not resume == '':
run_cmd += f' --resume={resume}'
if not float(prior_loss_weight) == 1.0:
run_cmd += f' --prior_loss_weight={prior_loss_weight}'

print(run_cmd)
# Run the command
subprocess.run(run_cmd)

# check if output_dir/last is a folder... therefore it is a diffuser model
last_dir = pathlib.Path(f'{output_dir}/last')

if not last_dir.is_dir():
# Copy inference model for v2 if required
save_inference_file(output_dir, v2, v_parameterization)
Expand Down Expand Up @@ -472,8 +496,8 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
)
config_file_name = gr.Textbox(
label='',
# placeholder="type the configuration file path or use the 'Open' button above to select it...",
interactive=False
placeholder="type the configuration file path or use the 'Open' button above to select it...",
interactive=True,
)
# config_file_name.change(
# remove_doublequote,
Expand All @@ -491,13 +515,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
document_symbol, elem_id='open_folder_small'
)
pretrained_model_name_or_path_fille.click(
get_file_path, inputs=[pretrained_model_name_or_path_input], outputs=pretrained_model_name_or_path_input
get_file_path,
inputs=[pretrained_model_name_or_path_input],
outputs=pretrained_model_name_or_path_input,
)
pretrained_model_name_or_path_folder = gr.Button(
folder_symbol, elem_id='open_folder_small'
)
pretrained_model_name_or_path_folder.click(
get_folder_path, outputs=pretrained_model_name_or_path_input
get_folder_path,
outputs=pretrained_model_name_or_path_input,
)
model_list = gr.Dropdown(
label='(Optional) Model Quick Pick',
Expand All @@ -517,10 +544,10 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
'same as source model',
'ckpt',
'diffusers',
"diffusers_safetensors",
'diffusers_safetensors',
'safetensors',
],
value='same as source model'
value='same as source model',
)
with gr.Row():
v2_input = gr.Checkbox(label='v2', value=True)
Expand Down Expand Up @@ -607,7 +634,9 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
)
with gr.Tab('Training parameters'):
with gr.Row():
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
learning_rate_input = gr.Textbox(
label='Learning rate', value=1e-6
)
lr_scheduler_input = gr.Dropdown(
label='LR Scheduler',
choices=[
Expand Down Expand Up @@ -662,7 +691,9 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
with gr.Row():
seed_input = gr.Textbox(label='Seed', value=1234)
max_resolution_input = gr.Textbox(
label='Max resolution', value='512,512', placeholder='512,512'
label='Max resolution',
value='512,512',
placeholder='512,512',
)
with gr.Row():
caption_extention_input = gr.Textbox(
Expand All @@ -676,27 +707,45 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
step=1,
label='Stop text encoder training',
)
with gr.Row():
full_fp16_input = gr.Checkbox(
label='Full fp16 training (experimental)', value=False
)
no_token_padding_input = gr.Checkbox(
label='No token padding', value=False
)

gradient_checkpointing_input = gr.Checkbox(
label='Gradient checkpointing', value=False
)
with gr.Row():
enable_bucket_input = gr.Checkbox(
label='Enable buckets', value=True
)
cache_latent_input = gr.Checkbox(label='Cache latent', value=True)
cache_latent_input = gr.Checkbox(
label='Cache latent', value=True
)
use_8bit_adam_input = gr.Checkbox(
label='Use 8bit adam', value=True
)
xformers_input = gr.Checkbox(label='Use xformers', value=True)

with gr.Accordion('Advanced Configuration', open=False):
with gr.Row():
full_fp16_input = gr.Checkbox(
label='Full fp16 training (experimental)', value=False
)
no_token_padding_input = gr.Checkbox(
label='No token padding', value=False
)

gradient_checkpointing_input = gr.Checkbox(
label='Gradient checkpointing', value=False
)

shuffle_caption = gr.Checkbox(
label='Shuffle caption', value=False
)
save_state = gr.Checkbox(label='Save state', value=False)
with gr.Row():
resume = gr.Textbox(
label='Resume',
placeholder='path to "last-state" state folder to resume from',
)
resume_button = gr.Button('📂', elem_id='open_folder_small')
resume_button.click(get_folder_path, outputs=resume)
prior_loss_weight = gr.Number(
label='Prior loss weight', value=1.0
)

button_run = gr.Button('Train model')

with gr.Tab('Utilities'):
Expand All @@ -713,8 +762,6 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
gradio_dataset_balancing_tab()
gradio_convert_model_tab()



button_open_config.click(
open_configuration,
inputs=[
Expand Down Expand Up @@ -746,7 +793,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown
save_model_as_dropdown,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
],
outputs=[
config_file_name,
Expand Down Expand Up @@ -777,7 +828,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown
save_model_as_dropdown,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
],
)

Expand Down Expand Up @@ -815,7 +870,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown
save_model_as_dropdown,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
],
outputs=[config_file_name],
)
Expand Down Expand Up @@ -852,7 +911,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown
save_model_as_dropdown,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
],
outputs=[config_file_name],
)
Expand Down Expand Up @@ -887,7 +950,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown
save_model_as_dropdown,
shuffle_caption,
save_state,
resume,
prior_loss_weight,
],
)

Expand Down
Loading

0 comments on commit 1f1dd5c

Please sign in to comment.