Skip to content

Commit

Permalink
v18: Save model as option added
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Dec 18, 2022
1 parent fc22813 commit f459c32
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 93 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ Drop by the discord server for support: https://discord.com/channels/10415185624
- Lord of the universe - cacoe (twitter: @cac0e)

## Change history

* 12/17 (v18) update:
- Save model as option added to train_db_fixed.py
- Save model as option added to GUI
  - Retire "Model conversion" parameters that were essentially performing the same function as the new `--save_model_as` parameter
* 12/17 (v17.2) update:
- Adding new dataset balancing utility.
* 12/17 (v17.1) update:
Expand Down
111 changes: 39 additions & 72 deletions dreambooth_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,16 @@ def save_configuration(
save_precision,
seed,
num_cpu_threads_per_process,
convert_to_safetensors,
convert_to_ckpt,
cache_latent,
caption_extention,
use_safetensors,
enable_bucket,
gradient_checkpointing,
full_fp16,
no_token_padding,
stop_text_encoder_training,
use_8bit_adam,
xformers,
save_model_as
):
original_file_path = file_path

Expand Down Expand Up @@ -103,18 +101,16 @@ def save_configuration(
'save_precision': save_precision,
'seed': seed,
'num_cpu_threads_per_process': num_cpu_threads_per_process,
'convert_to_safetensors': convert_to_safetensors,
'convert_to_ckpt': convert_to_ckpt,
'cache_latent': cache_latent,
'caption_extention': caption_extention,
'use_safetensors': use_safetensors,
'enable_bucket': enable_bucket,
'gradient_checkpointing': gradient_checkpointing,
'full_fp16': full_fp16,
'no_token_padding': no_token_padding,
'stop_text_encoder_training': stop_text_encoder_training,
'use_8bit_adam': use_8bit_adam,
'xformers': xformers,
'save_model_as': save_model_as
}

# Save the data to the selected file
Expand Down Expand Up @@ -144,18 +140,16 @@ def open_configuration(
save_precision,
seed,
num_cpu_threads_per_process,
convert_to_safetensors,
convert_to_ckpt,
cache_latent,
caption_extention,
use_safetensors,
enable_bucket,
gradient_checkpointing,
full_fp16,
no_token_padding,
stop_text_encoder_training,
use_8bit_adam,
xformers,
save_model_as
):

original_file_path = file_path
Expand Down Expand Up @@ -195,18 +189,16 @@ def open_configuration(
my_data.get(
'num_cpu_threads_per_process', num_cpu_threads_per_process
),
my_data.get('convert_to_safetensors', convert_to_safetensors),
my_data.get('convert_to_ckpt', convert_to_ckpt),
my_data.get('cache_latent', cache_latent),
my_data.get('caption_extention', caption_extention),
my_data.get('use_safetensors', use_safetensors),
my_data.get('enable_bucket', enable_bucket),
my_data.get('gradient_checkpointing', gradient_checkpointing),
my_data.get('full_fp16', full_fp16),
my_data.get('no_token_padding', no_token_padding),
my_data.get('stop_text_encoder_training', stop_text_encoder_training),
my_data.get('use_8bit_adam', use_8bit_adam),
my_data.get('xformers', xformers),
my_data.get('save_model_as', save_model_as)
)


Expand All @@ -229,18 +221,16 @@ def train_model(
save_precision,
seed,
num_cpu_threads_per_process,
convert_to_safetensors,
convert_to_ckpt,
cache_latent,
caption_extention,
use_safetensors,
enable_bucket,
gradient_checkpointing,
full_fp16,
no_token_padding,
stop_text_encoder_training_pct,
use_8bit_adam,
xformers,
save_model_as
):
def save_inference_file(output_dir, v2, v_parameterization):
# Copy inference model for v2 if required
Expand Down Expand Up @@ -352,8 +342,6 @@ def save_inference_file(output_dir, v2, v_parameterization):
run_cmd += ' --v_parameterization'
if cache_latent:
run_cmd += ' --cache_latents'
if use_safetensors:
run_cmd += ' --use_safetensors'
if enable_bucket:
run_cmd += ' --enable_bucket'
if gradient_checkpointing:
Expand Down Expand Up @@ -388,39 +376,20 @@ def save_inference_file(output_dir, v2, v_parameterization):
run_cmd += f' --logging_dir={logging_dir}'
run_cmd += f' --caption_extention={caption_extention}'
run_cmd += f' --stop_text_encoder_training={stop_text_encoder_training}'
if not save_model_as == 'same as source model':
run_cmd += f' --save_model_as={save_model_as}'

print(run_cmd)
# Run the command
subprocess.run(run_cmd)

# check if output_dir/last is a directory... therefore it is a diffuser model
# check if output_dir/last is a folder... therefore it is a diffuser model
last_dir = pathlib.Path(f'{output_dir}/last')
print(last_dir)
if last_dir.is_dir():
if convert_to_ckpt:
print(f'Converting diffuser model {last_dir} to {last_dir}.ckpt')
os.system(
f'python ./tools/convert_diffusers20_original_sd.py {last_dir} {last_dir}.ckpt --{save_precision}'
)

save_inference_file(output_dir, v2, v_parameterization)

if convert_to_safetensors:
print(
f'Converting diffuser model {last_dir} to {last_dir}.safetensors'
)
os.system(
f'python ./tools/convert_diffusers20_original_sd.py {last_dir} {last_dir}.safetensors --{save_precision}'
)

save_inference_file(output_dir, v2, v_parameterization)
else:

if not last_dir.is_dir():
# Copy inference model for v2 if required
save_inference_file(output_dir, v2, v_parameterization)

# Return the values of the variables as a dictionary
# return


def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
# define a list of substrings to search for
Expand Down Expand Up @@ -533,6 +502,17 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
'CompVis/stable-diffusion-v1-4',
],
)
save_model_as_dropdown = gr.Dropdown(
label='Save trained model as',
choices=[
'same as source model',
'ckpt',
'diffusers',
"diffusers_safetensors",
'safetensors',
],
value='same as source model'
)
with gr.Row():
v2_input = gr.Checkbox(label='v2', value=True)
v_parameterization_input = gr.Checkbox(
Expand All @@ -557,7 +537,7 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
with gr.Row():
train_data_dir_input = gr.Textbox(
label='Image folder',
placeholder='Directory where the training folders containing the images are located',
placeholder='Folder where the training folders containing the images are located',
)
train_data_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small'
Expand All @@ -567,7 +547,7 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
)
reg_data_dir_input = gr.Textbox(
label='Regularisation folder',
placeholder='(Optional) Directory where where the regularization folders containing the images are located',
placeholder='(Optional) Folder where where the regularization folders containing the images are located',
)
reg_data_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small'
Expand All @@ -577,8 +557,8 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
)
with gr.Row():
output_dir_input = gr.Textbox(
label='Output directory',
placeholder='Directory to output trained model',
label='Output folder',
placeholder='Folder to output trained model',
)
output_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small'
Expand All @@ -587,8 +567,8 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
get_folder_path, outputs=output_dir_input
)
logging_dir_input = gr.Textbox(
label='Logging directory',
placeholder='Optional: enable logging and output TensorBoard log to this directory',
label='Logging folder',
placeholder='Optional: enable logging and output TensorBoard log to this folder',
)
logging_dir_input_folder = gr.Button(
'📂', elem_id='open_folder_small'
Expand Down Expand Up @@ -694,9 +674,6 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
no_token_padding_input = gr.Checkbox(
label='No token padding', value=False
)
use_safetensors_input = gr.Checkbox(
label='Use safetensor when saving', value=False
)

gradient_checkpointing_input = gr.Checkbox(
label='Gradient checkpointing', value=False
Expand All @@ -711,13 +688,6 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
)
xformers_input = gr.Checkbox(label='Use xformers', value=True)

with gr.Tab('Model conversion'):
convert_to_safetensors_input = gr.Checkbox(
label='Convert to SafeTensors', value=True
)
convert_to_ckpt_input = gr.Checkbox(
label='Convert to CKPT', value=False
)
with gr.Tab('Utilities'):
# Dreambooth folder creation tab
gradio_dreambooth_folder_creation_tab(
Expand All @@ -729,6 +699,13 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
# Captionning tab
gradio_caption_gui_tab()
gradio_dataset_balancing_tab()
# with gr.Tab('Model conversion'):
# convert_to_safetensors_input = gr.Checkbox(
# label='Convert to SafeTensors', value=True
# )
# convert_to_ckpt_input = gr.Checkbox(
# label='Convert to CKPT', value=False
# )

button_run = gr.Button('Train model')

Expand All @@ -754,18 +731,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
convert_to_safetensors_input,
convert_to_ckpt_input,
cache_latent_input,
caption_extention_input,
use_safetensors_input,
enable_bucket_input,
gradient_checkpointing_input,
full_fp16_input,
no_token_padding_input,
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown
],
outputs=[
config_file_name,
Expand All @@ -787,18 +762,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
convert_to_safetensors_input,
convert_to_ckpt_input,
cache_latent_input,
caption_extention_input,
use_safetensors_input,
enable_bucket_input,
gradient_checkpointing_input,
full_fp16_input,
no_token_padding_input,
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown
],
)

Expand Down Expand Up @@ -827,18 +800,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
convert_to_safetensors_input,
convert_to_ckpt_input,
cache_latent_input,
caption_extention_input,
use_safetensors_input,
enable_bucket_input,
gradient_checkpointing_input,
full_fp16_input,
no_token_padding_input,
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown
],
outputs=[config_file_name],
)
Expand Down Expand Up @@ -866,18 +837,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
convert_to_safetensors_input,
convert_to_ckpt_input,
cache_latent_input,
caption_extention_input,
use_safetensors_input,
enable_bucket_input,
gradient_checkpointing_input,
full_fp16_input,
no_token_padding_input,
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown
],
outputs=[config_file_name],
)
Expand All @@ -903,18 +872,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
save_precision_input,
seed_input,
num_cpu_threads_per_process_input,
convert_to_safetensors_input,
convert_to_ckpt_input,
cache_latent_input,
caption_extention_input,
use_safetensors_input,
enable_bucket_input,
gradient_checkpointing_input,
full_fp16_input,
no_token_padding_input,
stop_text_encoder_training_input,
use_8bit_adam_input,
xformers_input,
save_model_as_dropdown
],
)

Expand Down
Loading

0 comments on commit f459c32

Please sign in to comment.