Add support for custom LyCORIS preset config toml base files (#2425)
* Add support for custom LyCORIS preset config toml base files
bmaltais committed May 1, 2024
1 parent 91350e5 commit e836bb5
Showing 3 changed files with 51 additions and 477 deletions.
README.md — 1 change: 1 addition & 0 deletions
@@ -436,3 +436,4 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG
 ### v24.1.0
 - To ensure cross-platform compatibility and security, the GUI now defaults to using "shell=False" when running subprocesses. This is based on documentation and should not cause issues on most platforms. However, some users have reported issues on specific platforms such as runpod and colab. Please open an issue if you encounter any issues.
+- Add support for custom LyCORIS toml config files. Simply type the path to the config file in the LyCORIS preset dropdown.
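A custom preset is just a toml file whose keys follow the LyCORIS Preset documentation (the same document linked from the dropdown's source below). A hypothetical minimal example — key names and module lists should be verified against the installed LyCORIS version:

```toml
# Hypothetical custom preset: apply LyCORIS only to UNet attention/resnet
# blocks and skip the text encoder. Verify keys against your LyCORIS version.
enable_conv = true
unet_target_module = ["Transformer2DModel", "ResnetBlock2D"]
unet_target_name = []
text_encoder_target_module = []
text_encoder_target_name = []
```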
kohya_gui/lora_gui.py — 75 changes: 50 additions & 25 deletions
Expand Up @@ -20,7 +20,7 @@
scriptdir,
update_my_data,
validate_paths,
validate_args_setting
validate_args_setting,
)
from .class_accelerate_launch import AccelerateLaunch
from .class_configuration_file import ConfigurationFile
@@ -60,6 +60,15 @@
 
 presets_dir = rf"{scriptdir}/presets"
 
+LYCORIS_PRESETS_CHOICES = [
+    "attn-mlp",
+    "attn-only",
+    "full",
+    "full-lin",
+    "unet-transformer-only",
+    "unet-convblock-only",
+]
+
 
 def save_configuration(
     save_as_bool,
@@ -667,13 +676,30 @@ def train_model(
     # Get list of function parameters and values
     parameters = list(locals().items())
     global train_state_value
 
     TRAIN_BUTTON_VISIBLE = [
         gr.Button(visible=True),
         gr.Button(visible=False or headless),
         gr.Textbox(value=train_state_value),
     ]
 
+    if LyCORIS_preset not in LYCORIS_PRESETS_CHOICES:
+        if not os.path.exists(LyCORIS_preset):
+            output_message(
+                msg=f"LyCORIS preset file {LyCORIS_preset} does not exist.",
+                headless=headless,
+            )
+            return TRAIN_BUTTON_VISIBLE
+        else:
+            try:
+                toml.load(LyCORIS_preset)
+            except:
+                output_message(
+                    msg=f"LyCORIS preset file {LyCORIS_preset} is not a valid toml file.",
+                    headless=headless,
+                )
+                return TRAIN_BUTTON_VISIBLE
+
     if executor.is_running():
         log.error("Training is already running. Can't start another training session.")
         return TRAIN_BUTTON_VISIBLE
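The block added above accepts either one of the built-in preset names or a path typed into the dropdown, rejecting paths that are missing or fail to parse. A minimal standalone sketch of the same check, with the bare `except` narrowed to the errors `toml.load` can realistically raise:

```python
import os

import toml  # same parser the GUI imports


def is_valid_lycoris_preset(preset: str, choices: list[str]) -> bool:
    """Accept a built-in preset name or a path to a parseable toml file."""
    if preset in choices:
        return True
    if not os.path.exists(preset):
        return False
    try:
        toml.load(preset)  # raises on malformed toml or unreadable files
        return True
    except (toml.TomlDecodeError, OSError):
        return False
```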
@@ -683,7 +709,7 @@ def train_model(
     log.info(f"Validating lr scheduler arguments...")
     if not validate_args_setting(lr_scheduler_args):
         return
 
     log.info(f"Validating optimizer arguments...")
     if not validate_args_setting(optimizer_args):
         return
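`validate_args_setting` lives in `common_gui.py` (see the import hunk above); the GUI calls it here to reject malformed extra-argument strings before they reach the training command. Purely as a hypothetical sketch of that kind of guard, assuming the expected shape is whitespace-separated `key=value` tokens:

```python
def validate_args_setting(args_setting: str) -> bool:
    """Hypothetical guard: every token must look like key=value."""
    for token in args_setting.split():
        key, sep, value = token.partition("=")
        if not (key and sep and value):
            return False
    return True
```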
@@ -954,7 +980,7 @@ def train_model(
 
         for key, value in kohya_lora_vars.items():
             if value:
-                network_args += f' {key}={value}'
+                network_args += f" {key}={value}"
 
     if LoRA_type in ["LoRA-FA"]:
         kohya_lora_var_list = [
@@ -983,7 +1009,7 @@ def train_model(
 
         for key, value in kohya_lora_vars.items():
             if value:
-                network_args += f' {key}={value}'
+                network_args += f" {key}={value}"
 
     if LoRA_type in ["Kohya DyLoRA"]:
         kohya_lora_var_list = [
@@ -1013,8 +1039,8 @@ def train_model(
 
         for key, value in kohya_lora_vars.items():
             if value:
-                network_args += f' {key}={value}'
+                network_args += f" {key}={value}"
 
     # Convert learning rates to float once and store the result for re-use
     learning_rate = float(learning_rate) if learning_rate is not None else 0.0
     text_encoder_lr_float = (
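Only the quoting style changes in these loops, but they are the mechanism that turns GUI fields into `--network_args` tokens: every non-falsy kohya variable is appended as a space-separated `key=value` pair. An illustrative run with hypothetical values:

```python
# Hypothetical values; falsy entries (0, "", None, False) are skipped.
kohya_lora_vars = {"conv_dim": 4, "conv_alpha": 1, "module_dropout": 0}

network_args = ""
for key, value in kohya_lora_vars.items():
    if value:
        network_args += f" {key}={value}"

print(network_args)  # -> ' conv_dim=4 conv_alpha=1'
```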
@@ -1079,7 +1105,9 @@ def train_model(
         "lr_scheduler": lr_scheduler,
         "lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
         "lr_scheduler_num_cycles": (
-            int(lr_scheduler_num_cycles) if lr_scheduler_num_cycles != "" else int(epoch)
+            int(lr_scheduler_num_cycles)
+            if lr_scheduler_num_cycles != ""
+            else int(epoch)
         ),
         "lr_scheduler_power": lr_scheduler_power,
         "lr_warmup_steps": lr_warmup_steps,
@@ -1088,7 +1116,9 @@ def train_model(
         "max_grad_norm": max_grad_norm,
         "max_timestep": max_timestep if max_timestep != 0 else None,
         "max_token_length": int(max_token_length),
-        "max_train_epochs": int(max_train_epochs) if int(max_train_epochs) != 0 else None,
+        "max_train_epochs": (
+            int(max_train_epochs) if int(max_train_epochs) != 0 else None
+        ),
         "max_train_steps": int(max_train_steps) if int(max_train_steps) != 0 else None,
         "mem_eff_attn": mem_eff_attn,
         "metadata_author": metadata_author,
@@ -1181,16 +1211,16 @@ def train_model(
         for key, value in config_toml_data.items()
         if value not in ["", False, None]
     }
 
     config_toml_data["max_data_loader_n_workers"] = int(max_data_loader_n_workers)
 
     # Sort the dictionary by keys
     config_toml_data = dict(sorted(config_toml_data.items()))
 
     current_datetime = datetime.now()
     formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")
     tmpfilename = f"./outputs/config_lora-{formatted_datetime}.toml"
 
     # Save the updated TOML data back to the file
     with open(tmpfilename, "w", encoding="utf-8") as toml_file:
         toml.dump(config_toml_data, toml_file)
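Taken together, the lines above filter out unset values, sort keys so the generated file diffs cleanly, and write a timestamped toml that the GUI then hands to the training script. A self-contained sketch of that pipeline (settings hypothetical):

```python
import os
from datetime import datetime

import toml

# Hypothetical settings; "", False and None act as "unset" sentinels.
config = {"seed": "", "xformers": False, "epoch": 10, "optimizer_type": "AdamW"}
config = {k: v for k, v in config.items() if v not in ["", False, None]}
config = dict(sorted(config.items()))  # stable key order

os.makedirs("./outputs", exist_ok=True)
stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
with open(f"./outputs/config_lora-{stamp}.toml", "w", encoding="utf-8") as f:
    toml.dump(config, f)  # file keeps only epoch and optimizer_type
```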
@@ -1229,14 +1259,14 @@ def train_model(
     # log.info(run_cmd)
     env = os.environ.copy()
     env["PYTHONPATH"] = (
-        fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
+        rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
    )
     env["TF_ENABLE_ONEDNN_OPTS"] = "0"
 
     # Run the command
 
     executor.execute_command(run_cmd=run_cmd, env=env)
 
     train_state_value = time.time()
 
     return (
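The `fr` → `rf` swap is purely stylistic (both prefixes denote the same raw f-string). The surrounding lines are what make the repo and its `sd-scripts` submodule importable from the child process; a sketch assuming a hypothetical install path:

```python
import os

env = os.environ.copy()
scriptdir = "/opt/kohya_ss"  # hypothetical install location
# Prepend the repo and sd-scripts to whatever PYTHONPATH already holds.
env["PYTHONPATH"] = (
    rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
# Disable TensorFlow's oneDNN custom ops for more predictable logs/numerics.
env["TF_ENABLE_ONEDNN_OPTS"] = "0"
```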
@@ -1357,17 +1387,11 @@ def list_presets(path):
                 )
                 LyCORIS_preset = gr.Dropdown(
                     label="LyCORIS Preset",
-                    choices=[
-                        "attn-mlp",
-                        "attn-only",
-                        "full",
-                        "full-lin",
-                        "unet-transformer-only",
-                        "unet-convblock-only",
-                    ],
+                    choices=LYCORIS_PRESETS_CHOICES,
                     value="full",
                     visible=False,
                     interactive=True,
+                    allow_custom_value=True,
                     # info="https://github.com/KohakuBlueleaf/LyCORIS/blob/0006e2ffa05a48d8818112d9f70da74c0cd30b99/docs/Preset.md"
                 )
         with gr.Group():
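`allow_custom_value=True` is the one-line UI change that enables the feature: the dropdown still offers the built-in presets, but a typed-in path is passed through verbatim to the validation shown earlier. A minimal standalone Gradio sketch (not the GUI's full wiring):

```python
import gradio as gr

LYCORIS_PRESETS_CHOICES = [
    "attn-mlp", "attn-only", "full", "full-lin",
    "unet-transformer-only", "unet-convblock-only",
]

with gr.Blocks() as demo:
    LyCORIS_preset = gr.Dropdown(
        label="LyCORIS Preset",
        choices=LYCORIS_PRESETS_CHOICES,
        value="full",
        interactive=True,
        allow_custom_value=True,  # lets the user type a path to a custom toml
    )

if __name__ == "__main__":
    demo.launch()
```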
@@ -2102,7 +2126,7 @@ def update_LoRA_settings(
 
     global executor
     executor = CommandExecutor(headless=headless)
 
     with gr.Column(), gr.Group():
         with gr.Row():
             button_print = gr.Button("Print training command")
@@ -2312,7 +2336,7 @@ def update_LoRA_settings(
     )
 
     run_state = gr.Textbox(value=train_state_value, visible=False)
 
     run_state.change(
         fn=executor.wait_for_training_to_end,
         outputs=[executor.button_run, executor.button_stop_training],
@@ -2326,7 +2350,8 @@ def update_LoRA_settings(
     )
 
     executor.button_stop_training.click(
-        executor.kill_command, outputs=[executor.button_run, executor.button_stop_training]
+        executor.kill_command,
+        outputs=[executor.button_run, executor.button_stop_training],
     )
 
     button_print.click(