Skip to content

Commit

Permalink
Update run_cmd_training syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed May 7, 2023
1 parent b158cb3 commit fe874aa
Showing 1 changed file with 61 additions and 43 deletions.
104 changes: 61 additions & 43 deletions library/common_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,48 +821,66 @@ def gradio_training(


def run_cmd_training(**kwargs):
options = [
f' --learning_rate="{kwargs.get("learning_rate", "")}"'
if kwargs.get('learning_rate')
else '',
f' --lr_scheduler="{kwargs.get("lr_scheduler", "")}"'
if kwargs.get('lr_scheduler')
else '',
f' --lr_warmup_steps="{kwargs.get("lr_warmup_steps", "")}"'
if kwargs.get('lr_warmup_steps')
else '',
f' --train_batch_size="{kwargs.get("train_batch_size", "")}"'
if kwargs.get('train_batch_size')
else '',
f' --max_train_steps="{kwargs.get("max_train_steps", "")}"'
if kwargs.get('max_train_steps')
else '',
f' --save_every_n_epochs="{int(kwargs.get("save_every_n_epochs", 1))}"'
if int(kwargs.get('save_every_n_epochs'))
else '',
f' --mixed_precision="{kwargs.get("mixed_precision", "")}"'
if kwargs.get('mixed_precision')
else '',
f' --save_precision="{kwargs.get("save_precision", "")}"'
if kwargs.get('save_precision')
else '',
f' --seed="{kwargs.get("seed", "")}"'
if kwargs.get('seed') != ''
else '',
f' --caption_extension="{kwargs.get("caption_extension", "")}"'
if kwargs.get('caption_extension')
else '',
' --cache_latents' if kwargs.get('cache_latents') else '',
' --cache_latents_to_disk'
if kwargs.get('cache_latents_to_disk')
else '',
# ' --use_lion_optimizer' if kwargs.get('optimizer') == 'Lion' else '',
f' --optimizer_type="{kwargs.get("optimizer", "AdamW")}"',
f' --optimizer_args {kwargs.get("optimizer_args", "")}'
if not kwargs.get('optimizer_args') == ''
else '',
]
run_cmd = ''.join(options)
run_cmd = ''

learning_rate = kwargs.get("learning_rate", "")
if learning_rate:
run_cmd += f' --learning_rate="{learning_rate}"'

lr_scheduler = kwargs.get("lr_scheduler", "")
if lr_scheduler:
run_cmd += f' --lr_scheduler="{lr_scheduler}"'

lr_warmup_steps = kwargs.get("lr_warmup_steps", "")
if lr_warmup_steps:
if lr_scheduler == 'constant':
print('Can\'t use LR warmup with LR Scheduler constant... ignoring...')
else:
run_cmd += f' --lr_warmup_steps="{lr_warmup_steps}"'

train_batch_size = kwargs.get("train_batch_size", "")
if train_batch_size:
run_cmd += f' --train_batch_size="{train_batch_size}"'

max_train_steps = kwargs.get("max_train_steps", "")
if max_train_steps:
run_cmd += f' --max_train_steps="{max_train_steps}"'

save_every_n_epochs = kwargs.get("save_every_n_epochs")
if save_every_n_epochs:
run_cmd += f' --save_every_n_epochs="{int(save_every_n_epochs)}"'

mixed_precision = kwargs.get("mixed_precision", "")
if mixed_precision:
run_cmd += f' --mixed_precision="{mixed_precision}"'

save_precision = kwargs.get("save_precision", "")
if save_precision:
run_cmd += f' --save_precision="{save_precision}"'

seed = kwargs.get("seed", "")
if seed != '':
run_cmd += f' --seed="{seed}"'

caption_extension = kwargs.get("caption_extension", "")
if caption_extension:
run_cmd += f' --caption_extension="{caption_extension}"'

cache_latents = kwargs.get('cache_latents')
if cache_latents:
run_cmd += ' --cache_latents'

cache_latents_to_disk = kwargs.get('cache_latents_to_disk')
if cache_latents_to_disk:
run_cmd += ' --cache_latents_to_disk'

optimizer_type = kwargs.get("optimizer", "AdamW")
run_cmd += f' --optimizer_type="{optimizer_type}"'

optimizer_args = kwargs.get("optimizer_args", "")
if optimizer_args != '':
run_cmd += f' --optimizer_args {optimizer_args}'

return run_cmd


Expand Down Expand Up @@ -1084,7 +1102,7 @@ def run_cmd_advanced_training(**kwargs):

max_train_epochs = kwargs.get("max_train_epochs", "")
if max_train_epochs:
run_cmd += ' --max_train_epochs={max_train_epochs}'
run_cmd += f' --max_train_epochs={max_train_epochs}'

max_data_loader_n_workers = kwargs.get("max_data_loader_n_workers", "")
if max_data_loader_n_workers:
Expand Down

0 comments on commit fe874aa

Please sign in to comment.