Skip to content

Commit

Permalink
Merge pull request #2137 from bmaltais/dev
Browse files Browse the repository at this point in the history
v23.0.15
  • Loading branch information
bmaltais committed Mar 21, 2024
2 parents 55386bb + aaf0396 commit 5bbb4fc
Show file tree
Hide file tree
Showing 25 changed files with 1,723 additions and 653 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,5 @@ dataset/**
!dataset/**/.gitkeep
models
data
config.toml
config.toml
sd-scripts
2 changes: 1 addition & 1 deletion .release
Original file line number Diff line number Diff line change
@@ -1 +1 @@
v23.0.14
v23.0.15
19 changes: 14 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ The GUI allows you to set the training parameters and generate and run the requi
- [No module called tkinter](#no-module-called-tkinter)
- [SDXL training](#sdxl-training)
- [Change History](#change-history)
- [2024/03/13 (v23.0.14)](#20240313-v23014)
- [2024/03/13 (v23.0.13)](#20240313-v23013)
- [2024/03/20 (v23.0.15)](#20240320-v23015)
- [2024/03/19 (v23.0.14)](#20240319-v23014)
- [2024/03/19 (v23.0.13)](#20240319-v23013)
- [2024/03/16 (v23.0.12)](#20240316-v23012)
- [New Features \& Improvements](#new-features--improvements)
- [Software Updates](#software-updates)
Expand Down Expand Up @@ -381,11 +382,19 @@ The documentation in this section will be moved to a separate document later.
## Change History
### 2024/03/13 (v23.0.14)
### 2024/03/21 (v23.0.15)
- Add support for toml dataset configuration file to all trainers
- Add new setup menu option to install Triton 2.1.0 for Windows
- Add support for LyCORIS BOFT and DoRA and QLyCORIS options for LoHA, LoKr and LoCon
- Fix issue with vae path validation
- Other fixes
### 2024/03/19 (v23.0.14)
- Fix blip caption issue
-
### 2024/03/13 (v23.0.13)
### 2024/03/19 (v23.0.13)
- Fix issue with image samples.
Expand Down
Empty file.
10 changes: 3 additions & 7 deletions kohya_gui/class_advanced_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
list_files,
list_dirs,
create_refresh_button,
document_symbol
)


Expand Down Expand Up @@ -38,8 +39,7 @@ def __init__(
headless (bool): Run in headless mode without GUI.
finetuning (bool): Enable model fine-tuning.
training_type (str): The type of training to be performed.
default_vae_dir (str): Default directory for VAE models.
default_output_dir (str): Default directory for output files.
config (dict): Configuration options for the training process.
"""
self.headless = headless
self.finetuning = finetuning
Expand Down Expand Up @@ -368,10 +368,6 @@ def list_state_dirs(path):
outputs=self.resume,
show_progress=False,
)
# self.max_train_epochs = gr.Textbox(
# label='Max train epoch',
# placeholder='(Optional) Override number of epoch',
# )
self.max_data_loader_n_workers = gr.Textbox(
label="Max num workers for DataLoader",
placeholder="(Optional) Override number of epoch. Default: 8",
Expand Down Expand Up @@ -437,7 +433,7 @@ def list_log_tracker_config_files(path):
"open_folder_small",
)
self.log_tracker_config_button = gr.Button(
"📂", elem_id="open_folder_small", visible=(not headless)
document_symbol, elem_id="open_folder_small", visible=(not headless)
)
self.log_tracker_config_button.click(
get_any_file_path,
Expand Down
2 changes: 1 addition & 1 deletion kohya_gui/class_folders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import gradio as gr
import os
from .common_gui import get_folder_path, scriptdir, list_dirs, create_refresh_button
from .common_gui import get_folder_path, scriptdir, list_dirs, list_files, create_refresh_button

class Folders:
"""
Expand Down
67 changes: 58 additions & 9 deletions kohya_gui/class_source_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os

from .common_gui import (
get_any_file_path,
get_file_path,
get_folder_path,
set_pretrained_model_name_or_path_input,
scriptdir,
Expand Down Expand Up @@ -61,6 +61,8 @@ def __init__(
self.current_train_data_dir = self.config.get(
"train_data_dir", os.path.join(scriptdir, "data")
)
self.current_dataset_config_dir = self.config.get('dataset_config_dir', os.path.join(scriptdir, "dataset_config"))


model_checkpoints = list(
list_files(
Expand All @@ -79,6 +81,21 @@ def list_models(path):
def list_train_data_dirs(path):
    """Remember the last-browsed training-data path and list its subdirectories."""
    # An empty selection falls back to the current working directory.
    self.current_train_data_dir = "." if path == "" else path
    return list(list_dirs(path))

def list_dataset_config_dirs(path: str) -> list:
    """
    List the .toml dataset-configuration files found under *path*.

    Parameters:
    - path (str): The directory to search. An empty string falls back to ".".

    Returns:
    - list: The .toml files (searched recursively) under *path*, for
      populating the dataset-config dropdown choices.
    """
    # Fix: store the browsed path on the instance (the original assigned a
    # throwaway local), so the refresh button — which re-reads
    # self.current_dataset_config_dir — picks up the user's latest
    # selection. Mirrors how list_train_data_dirs tracks
    # self.current_train_data_dir.
    self.current_dataset_config_dir = path if not path == "" else "."
    # Lists all .toml files (not .json — the original comment was stale)
    # in the current configuration directory.
    return list(list_files(self.current_dataset_config_dir, exts=[".toml"], all=True))


with gr.Column(), gr.Group():
# Define the input elements
Expand Down Expand Up @@ -107,7 +124,7 @@ def list_train_data_dirs(path):
visible=(not headless),
)
self.pretrained_model_name_or_path_file.click(
get_any_file_path,
get_file_path,
inputs=self.pretrained_model_name_or_path,
outputs=self.pretrained_model_name_or_path,
show_progress=False,
Expand All @@ -124,7 +141,15 @@ def list_train_data_dirs(path):
outputs=self.pretrained_model_name_or_path,
show_progress=False,
)


with gr.Column(), gr.Row():
self.output_name = gr.Textbox(
label="Trained Model output name",
placeholder="(Name of the model to output)",
value="last",
interactive=True,
)
with gr.Row():
with gr.Column(), gr.Row():
self.train_data_dir = gr.Dropdown(
label=(
Expand Down Expand Up @@ -158,6 +183,36 @@ def list_train_data_dirs(path):
outputs=self.train_data_dir,
show_progress=False,
)
with gr.Column(), gr.Row():
# Toml directory dropdown
self.dataset_config = gr.Dropdown(
label='Dataset config file (Optional. Select the toml configuration file to use for the dataset)',
choices=[""] + list_dataset_config_dirs(self.current_dataset_config_dir),
value="",
interactive=True,
allow_custom_value=True,
)
# Refresh button for dataset_config directory
create_refresh_button(self.dataset_config, lambda: None, lambda: {"choices": [""] + list_dataset_config_dirs(self.current_dataset_config_dir)}, "open_folder_small")
# Toml directory button
self.dataset_config_folder = gr.Button(
document_symbol, elem_id='open_folder_small', elem_classes=["tool"], visible=(not self.headless)
)

# Toml directory button click event
self.dataset_config_folder.click(
get_file_path,
inputs=[self.dataset_config, gr.Textbox(value='*.toml', visible=False), gr.Textbox(value='Dataset config types', visible=False)],
outputs=self.dataset_config,
show_progress=False,
)
# Change event for dataset_config directory dropdown
self.dataset_config.change(
fn=lambda path: gr.Dropdown(choices=[""] + list_dataset_config_dirs(path)),
inputs=self.dataset_config,
outputs=self.dataset_config,
show_progress=False,
)

with gr.Row():
with gr.Column():
Expand All @@ -181,12 +236,6 @@ def list_train_data_dirs(path):
gr.Box(visible=False)

with gr.Row():
self.output_name = gr.Textbox(
label="Trained Model output name",
placeholder="(Name of the model to output)",
value="last",
interactive=True,
)
self.training_comment = gr.Textbox(
label="Training comment",
placeholder="(Optional) Add training comment to be included in metadata",
Expand Down
29 changes: 29 additions & 0 deletions kohya_gui/common_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,23 @@

ENV_EXCLUSION = ["COLAB_GPU", "RUNPOD_POD_ID"]

def calculate_max_train_steps(
    total_steps: int,
    train_batch_size: int,
    gradient_accumulation_steps: int,
    epoch: int,
    reg_factor: int,
):
    """
    Compute the total number of optimizer steps for a training run.

    Parameters:
    - total_steps (int): Base step count (e.g. image count x repeats).
    - train_batch_size (int): Samples processed per step.
    - gradient_accumulation_steps (int): Steps accumulated per optimizer update.
    - epoch (int): Number of epochs to train for.
    - reg_factor (int): Multiplier applied when regularization images are used.

    Returns:
    - int: ceil(total_steps / train_batch_size / gradient_accumulation_steps
      * epoch * reg_factor).
    """
    # Same float arithmetic and operation order as before, spelled out in
    # named stages for readability.
    steps_per_epoch = (
        float(total_steps) / int(train_batch_size) / int(gradient_accumulation_steps)
    )
    scaled = steps_per_epoch * int(epoch) * int(reg_factor)
    return int(math.ceil(scaled))

def check_if_model_exist(
output_name: str, output_dir: str, save_model_as: str, headless: bool = False
) -> bool:
Expand Down Expand Up @@ -1077,6 +1094,11 @@ def run_cmd_advanced_training(**kwargs):
if color_aug:
run_cmd += " --color_aug"

dataset_config = kwargs.get("dataset_config")
if dataset_config:
dataset_config = os.path.abspath(os.path.normpath(dataset_config))
run_cmd += f' --dataset_config="{dataset_config}"'

dataset_repeats = kwargs.get("dataset_repeats")
if dataset_repeats:
run_cmd += f' --dataset_repeats="{dataset_repeats}"'
Expand Down Expand Up @@ -1753,6 +1775,13 @@ def validate_path(
if key in ["output_dir", "logging_dir"]:
if not validate_path(value, key, create_if_missing=True):
return False
elif key in ["vae"]:
# Check if it matches the Hugging Face model pattern
if re.match(r"^[\w-]+\/[\w-]+$", value):
log.info("Checking vae... huggingface.co model, skipping validation")
else:
if not validate_path(value, key):
return False
else:
if key not in ["pretrained_model_name_or_path"]:
if not validate_path(value, key):
Expand Down
Loading

0 comments on commit 5bbb4fc

Please sign in to comment.