Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom LyCORIS preset config toml base files #2425

Merged
merged 2 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -436,3 +436,4 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG
### v24.1.0

- To ensure cross-platform compatibility and security, the GUI now defaults to using "shell=False" when running subprocesses. This is based on documentation and should not cause issues on most platforms. However, some users have reported issues on specific platforms such as runpod and colab. PLease open an issue if you encounter any issues.
- Add support for custom LyCORIS toml config files. Simply type the path to the config file in the LyCORIS preset dropdown.
75 changes: 50 additions & 25 deletions kohya_gui/lora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
scriptdir,
update_my_data,
validate_paths,
validate_args_setting
validate_args_setting,
)
from .class_accelerate_launch import AccelerateLaunch
from .class_configuration_file import ConfigurationFile
Expand Down Expand Up @@ -60,6 +60,15 @@

presets_dir = rf"{scriptdir}/presets"

LYCORIS_PRESETS_CHOICES = [
"attn-mlp",
"attn-only",
"full",
"full-lin",
"unet-transformer-only",
"unet-convblock-only",
]


def save_configuration(
save_as_bool,
Expand Down Expand Up @@ -667,13 +676,30 @@ def train_model(
# Get list of function parameters and values
parameters = list(locals().items())
global train_state_value

TRAIN_BUTTON_VISIBLE = [
gr.Button(visible=True),
gr.Button(visible=False or headless),
gr.Textbox(value=train_state_value),
]

if LyCORIS_preset not in LYCORIS_PRESETS_CHOICES:
if not os.path.exists(LyCORIS_preset):
output_message(
msg=f"LyCORIS preset file {LyCORIS_preset} does not exist.",
headless=headless,
)
return TRAIN_BUTTON_VISIBLE
else:
try:
toml.load(LyCORIS_preset)
except:
output_message(
msg=f"LyCORIS preset file {LyCORIS_preset} is not a valid toml file.",
headless=headless,
)
return TRAIN_BUTTON_VISIBLE

if executor.is_running():
log.error("Training is already running. Can't start another training session.")
return TRAIN_BUTTON_VISIBLE
Expand All @@ -683,7 +709,7 @@ def train_model(
log.info(f"Validating lr scheduler arguments...")
if not validate_args_setting(lr_scheduler_args):
return

log.info(f"Validating optimizer arguments...")
if not validate_args_setting(optimizer_args):
return
Expand Down Expand Up @@ -954,7 +980,7 @@ def train_model(

for key, value in kohya_lora_vars.items():
if value:
network_args += f' {key}={value}'
network_args += f" {key}={value}"

if LoRA_type in ["LoRA-FA"]:
kohya_lora_var_list = [
Expand Down Expand Up @@ -983,7 +1009,7 @@ def train_model(

for key, value in kohya_lora_vars.items():
if value:
network_args += f' {key}={value}'
network_args += f" {key}={value}"

if LoRA_type in ["Kohya DyLoRA"]:
kohya_lora_var_list = [
Expand Down Expand Up @@ -1013,8 +1039,8 @@ def train_model(

for key, value in kohya_lora_vars.items():
if value:
network_args += f' {key}={value}'
network_args += f" {key}={value}"

# Convert learning rates to float once and store the result for re-use
learning_rate = float(learning_rate) if learning_rate is not None else 0.0
text_encoder_lr_float = (
Expand Down Expand Up @@ -1079,7 +1105,9 @@ def train_model(
"lr_scheduler": lr_scheduler,
"lr_scheduler_args": str(lr_scheduler_args).replace('"', "").split(),
"lr_scheduler_num_cycles": (
int(lr_scheduler_num_cycles) if lr_scheduler_num_cycles != "" else int(epoch)
int(lr_scheduler_num_cycles)
if lr_scheduler_num_cycles != ""
else int(epoch)
),
"lr_scheduler_power": lr_scheduler_power,
"lr_warmup_steps": lr_warmup_steps,
Expand All @@ -1088,7 +1116,9 @@ def train_model(
"max_grad_norm": max_grad_norm,
"max_timestep": max_timestep if max_timestep != 0 else None,
"max_token_length": int(max_token_length),
"max_train_epochs": int(max_train_epochs) if int(max_train_epochs) != 0 else None,
"max_train_epochs": (
int(max_train_epochs) if int(max_train_epochs) != 0 else None
),
"max_train_steps": int(max_train_steps) if int(max_train_steps) != 0 else None,
"mem_eff_attn": mem_eff_attn,
"metadata_author": metadata_author,
Expand Down Expand Up @@ -1181,16 +1211,16 @@ def train_model(
for key, value in config_toml_data.items()
if value not in ["", False, None]
}

config_toml_data["max_data_loader_n_workers"] = int(max_data_loader_n_workers)

# Sort the dictionary by keys
config_toml_data = dict(sorted(config_toml_data.items()))

current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime("%Y%m%d-%H%M%S")
tmpfilename = f"./outputs/config_lora-{formatted_datetime}.toml"

# Save the updated TOML data back to the file
with open(tmpfilename, "w", encoding="utf-8") as toml_file:
toml.dump(config_toml_data, toml_file)
Expand Down Expand Up @@ -1229,14 +1259,14 @@ def train_model(
# log.info(run_cmd)
env = os.environ.copy()
env["PYTHONPATH"] = (
fr"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
rf"{scriptdir}{os.pathsep}{scriptdir}/sd-scripts{os.pathsep}{env.get('PYTHONPATH', '')}"
)
env["TF_ENABLE_ONEDNN_OPTS"] = "0"

# Run the command

executor.execute_command(run_cmd=run_cmd, env=env)

train_state_value = time.time()

return (
Expand Down Expand Up @@ -1357,17 +1387,11 @@ def list_presets(path):
)
LyCORIS_preset = gr.Dropdown(
label="LyCORIS Preset",
choices=[
"attn-mlp",
"attn-only",
"full",
"full-lin",
"unet-transformer-only",
"unet-convblock-only",
],
choices=LYCORIS_PRESETS_CHOICES,
value="full",
visible=False,
interactive=True,
allow_custom_value=True,
# info="https://github.com/KohakuBlueleaf/LyCORIS/blob/0006e2ffa05a48d8818112d9f70da74c0cd30b99/docs/Preset.md"
)
with gr.Group():
Expand Down Expand Up @@ -2102,7 +2126,7 @@ def update_LoRA_settings(

global executor
executor = CommandExecutor(headless=headless)

with gr.Column(), gr.Group():
with gr.Row():
button_print = gr.Button("Print training command")
Expand Down Expand Up @@ -2312,7 +2336,7 @@ def update_LoRA_settings(
)

run_state = gr.Textbox(value=train_state_value, visible=False)

run_state.change(
fn=executor.wait_for_training_to_end,
outputs=[executor.button_run, executor.button_stop_training],
Expand All @@ -2326,7 +2350,8 @@ def update_LoRA_settings(
)

executor.button_stop_training.click(
executor.kill_command, outputs=[executor.button_run, executor.button_stop_training]
executor.kill_command,
outputs=[executor.button_run, executor.button_stop_training],
)

button_print.click(
Expand Down
Loading