Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add validation of lr scheduler and optimizer arguments #2358

Merged
merged 1 commit into from
Apr 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG
### 2024/04/20 (v24.0.6)

- Make start and stop buttons visible in headless
- Add validation for lr scheduler and optimizer arguments

### 2024/04/19 (v24.0.5)

Expand Down
15 changes: 14 additions & 1 deletion kohya_gui/common_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,4 +1498,17 @@ def print_command_and_toml(run_cmd, tmpfilename):
log.info(toml_file.read())
log.info(f"end of toml config file: {tmpfilename}")

save_to_file(command_to_run)
save_to_file(command_to_run)

def validate_args_setting(input_string):
    """Validate a settings string of space-separated ``key=value`` pairs.

    A valid settings string is either empty or consists of one or more
    key/value pairs formatted as ``key=value``, separated by exactly one
    space, with no spaces around the equals sign and no spaces within
    keys or values.

    Args:
        input_string: Raw settings string from the GUI. ``None`` is
            treated the same as an empty string (nothing to validate).

    Returns:
        bool: ``True`` if the string is a valid settings string,
        ``False`` otherwise (a hint is logged in that case).
    """
    # An unset value behaves like an empty (valid) settings string.
    if input_string is None or input_string == "":
        return True
    # One or more key=value pairs with exactly one space between pairs;
    # \S forbids spaces inside keys/values and around '='.
    pattern = r"(\S+=\S+)( \S+=\S+)*"
    # re.fullmatch anchors the whole string and, unlike a '$' anchor,
    # does not accept a trailing newline, so "a=b\n" is rejected.
    if re.fullmatch(pattern, input_string):
        return True
    log.info(f"'{input_string}' is not a valid settings string.")
    log.info("A valid settings string must consist of one or more key/value pairs formatted as key=value, with no spaces around the equals sign or within the value. Multiple pairs should be separated by a space.")
    return False
28 changes: 19 additions & 9 deletions kohya_gui/dreambooth_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
scriptdir,
update_my_data,
validate_paths,
validate_args_setting,
)
from .class_accelerate_launch import AccelerateLaunch
from .class_configuration_file import ConfigurationFile
Expand Down Expand Up @@ -491,19 +492,27 @@ def train_model(
# Get list of function parameters and values
parameters = list(locals().items())
global train_state_value

TRAIN_BUTTON_VISIBLE = [
gr.Button(visible=True),
gr.Button(visible=False or headless),
gr.Textbox(value=train_state_value),
]

if executor.is_running():
log.error("Training is already running. Can't start another training session.")
return TRAIN_BUTTON_VISIBLE

log.info(f"Start training Dreambooth...")

log.info(f"Validating lr scheduler arguments...")
if not validate_args_setting(lr_scheduler_args):
return

log.info(f"Validating optimizer arguments...")
if not validate_args_setting(optimizer_args):
return

# This function validates files or folder paths. Simply add new variables containing file of folder path
# to validate below
if not validate_paths(
Expand Down Expand Up @@ -808,9 +817,9 @@ def train_model(
for key, value in config_toml_data.items()
if value not in ["", False, None]
}

config_toml_data["max_data_loader_n_workers"] = max_data_loader_n_workers

# Sort the dictionary by keys
config_toml_data = dict(sorted(config_toml_data.items()))

Expand Down Expand Up @@ -861,7 +870,7 @@ def train_model(
# Run the command

executor.execute_command(run_cmd=run_cmd, use_shell=use_shell, env=env)

train_state_value = time.time()

return (
Expand Down Expand Up @@ -950,7 +959,7 @@ def dreambooth_tab(

global executor
executor = CommandExecutor(headless=headless)

with gr.Column(), gr.Group():
with gr.Row():
button_print = gr.Button("Print training command")
Expand Down Expand Up @@ -1102,9 +1111,9 @@ def dreambooth_tab(
outputs=[configuration.config_file_name],
show_progress=False,
)

run_state = gr.Textbox(value=train_state_value, visible=False)

run_state.change(
fn=executor.wait_for_training_to_end,
outputs=[executor.button_run, executor.button_stop_training],
Expand All @@ -1118,7 +1127,8 @@ def dreambooth_tab(
)

executor.button_stop_training.click(
executor.kill_command, outputs=[executor.button_run, executor.button_stop_training]
executor.kill_command,
outputs=[executor.button_run, executor.button_stop_training],
)

button_print.click(
Expand Down
9 changes: 9 additions & 0 deletions kohya_gui/finetune_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
scriptdir,
update_my_data,
validate_paths,
validate_args_setting
)
from .class_accelerate_launch import AccelerateLaunch
from .class_configuration_file import ConfigurationFile
Expand Down Expand Up @@ -544,6 +545,14 @@ def train_model(

log.info(f"Start Finetuning...")

log.info(f"Validating lr scheduler arguments...")
if not validate_args_setting(lr_scheduler_args):
return

log.info(f"Validating optimizer arguments...")
if not validate_args_setting(optimizer_args):
return

if train_dir != "" and not os.path.exists(train_dir):
os.mkdir(train_dir)

Expand Down
9 changes: 9 additions & 0 deletions kohya_gui/lora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
scriptdir,
update_my_data,
validate_paths,
validate_args_setting
)
from .class_accelerate_launch import AccelerateLaunch
from .class_configuration_file import ConfigurationFile
Expand Down Expand Up @@ -679,6 +680,14 @@ def train_model(

log.info(f"Start training LoRA {LoRA_type} ...")

log.info(f"Validating lr scheduler arguments...")
if not validate_args_setting(lr_scheduler_args):
return

log.info(f"Validating optimizer arguments...")
if not validate_args_setting(optimizer_args):
return

if not validate_paths(
output_dir=output_dir,
pretrained_model_name_or_path=pretrained_model_name_or_path,
Expand Down
9 changes: 9 additions & 0 deletions kohya_gui/textual_inversion_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
scriptdir,
update_my_data,
validate_paths,
validate_args_setting
)
from .class_accelerate_launch import AccelerateLaunch
from .class_configuration_file import ConfigurationFile
Expand Down Expand Up @@ -505,6 +506,14 @@ def train_model(

log.info(f"Start training TI...")

log.info(f"Validating lr scheduler arguments...")
if not validate_args_setting(lr_scheduler_args):
return

log.info(f"Validating optimizer arguments...")
if not validate_args_setting(optimizer_args):
return

if not validate_paths(
output_dir=output_dir,
pretrained_model_name_or_path=pretrained_model_name_or_path,
Expand Down