Skip to content

Commit

Permalink
Merge pull request #2055 from bmaltais/dev
Browse files Browse the repository at this point in the history
v23.0.2
  • Loading branch information
bmaltais committed Mar 10, 2024
2 parents 7dae63e + ab188dd commit 7895de6
Show file tree
Hide file tree
Showing 16 changed files with 243 additions and 232 deletions.
2 changes: 1 addition & 1 deletion .release
Original file line number Diff line number Diff line change
@@ -1 +1 @@
v23.0.1
v23.0.2
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ The GUI allows you to set the training parameters and generate and run the requi
- [No module called tkinter](#no-module-called-tkinter)
- [SDXL training](#sdxl-training)
- [Change History](#change-history)
- [2024/03/10 (v23.0.1)](#20240310-v2301)
- [2024/03/10 (v23.0.2)](#20240310-v2302)
- [2024/03/09 (v23.0.1)](#20240309-v2301)
- [2024/03/02 (v23.0.0)](#20240302-v2300)

## 🦒 Colab
Expand Down Expand Up @@ -364,7 +365,11 @@ The documentation in this section will be moved to a separate document later.
## Change History
### 2024/03/10 (v23.0.1)
### 2024/03/10 (v23.0.2)
- Improve validation of path provided by users before running training
### 2024/03/09 (v23.0.1)
- Update bitsandbytes module to 0.43.0 as it provide native windows support
- Minor fixes to code
Expand Down
9 changes: 7 additions & 2 deletions gui.bat
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
@echo off

set PYTHON_VER=3.10.9

:: Deactivate the virtual environment
call .\venv\Scripts\deactivate.bat

:: Calling external python program to check for local modules
:: python .\setup\check_local_modules.py --no_question
:: Check if Python version meets the recommended version
python --version 2>nul | findstr /b /c:"Python %PYTHON_VER%" >nul
if errorlevel 1 (
echo Warning: Python version %PYTHON_VER% is required. Kohya_ss GUI will most likely fail to run.
)

:: Activate the virtual environment
call .\venv\Scripts\activate.bat
Expand Down
25 changes: 12 additions & 13 deletions kohya_gui/class_source_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄

default_models = [
'stabilityai/stable-diffusion-xl-base-1.0',
'stabilityai/stable-diffusion-xl-refiner-1.0',
'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned',
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned',
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
'runwayml/stable-diffusion-v1-5',
'CompVis/stable-diffusion-v1-4',
]

class SourceModel:
def __init__(
Expand All @@ -39,19 +51,6 @@ def __init__(
self.save_model_as_choices = save_model_as_choices
self.finetuning = finetuning

default_models = [
'stabilityai/stable-diffusion-xl-base-1.0',
'stabilityai/stable-diffusion-xl-refiner-1.0',
'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned',
'stabilityai/stable-diffusion-2-1-base',
'stabilityai/stable-diffusion-2-base',
'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned',
'stabilityai/stable-diffusion-2-1',
'stabilityai/stable-diffusion-2',
'runwayml/stable-diffusion-v1-5',
'CompVis/stable-diffusion-v1-4',
]

from .common_gui import create_refresh_button

default_data_dir = default_data_dir if default_data_dir is not None else os.path.join(scriptdir, "outputs")
Expand Down
152 changes: 138 additions & 14 deletions kohya_gui/common_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import json

from .custom_logging import setup_logging
from datetime import datetime

# Set up logging
log = setup_logging()
Expand Down Expand Up @@ -699,7 +698,7 @@ def run_cmd_advanced_training(**kwargs):
if "additional_parameters" in kwargs:
run_cmd += f' {kwargs["additional_parameters"]}'

if "block_lr" in kwargs:
if "block_lr" in kwargs and kwargs["block_lr"] != "":
run_cmd += f' --block_lr="{kwargs["block_lr"]}"'

if kwargs.get("bucket_no_upscale"):
Expand Down Expand Up @@ -1143,12 +1142,12 @@ def run_cmd_advanced_training(**kwargs):

def verify_image_folder_pattern(folder_path):
false_response = True # temporarily set to true to prevent stopping training in case of false positive
true_response = True

log.info(f"Verifying image folder pattern of {folder_path}...")
# Check if the folder exists
if not os.path.isdir(folder_path):
log.error(
f"The provided path '{folder_path}' is not a valid folder. Please follow the folder structure documentation found at docs\image_folder_structure.md ..."
f"...the provided path '{folder_path}' is not a valid folder. Please follow the folder structure documentation found at docs\image_folder_structure.md ..."
)
return false_response

Expand Down Expand Up @@ -1176,22 +1175,22 @@ def verify_image_folder_pattern(folder_path):
non_matching_subfolders = set(subfolders) - set(matching_subfolders)
if non_matching_subfolders:
log.error(
f"The following folders do not match the required pattern <number>_<text>: {', '.join(non_matching_subfolders)}"
f"...the following folders do not match the required pattern <number>_<text>: {', '.join(non_matching_subfolders)}"
)
log.error(
f"Please follow the folder structure documentation found at docs\image_folder_structure.md ..."
f"...please follow the folder structure documentation found at docs\image_folder_structure.md ..."
)
return false_response

# Check if no sub-folders exist
if not matching_subfolders:
log.error(
f"No image folders found in {folder_path}. Please follow the folder structure documentation found at docs\image_folder_structure.md ..."
f"...no image folders found in {folder_path}. Please follow the folder structure documentation found at docs\image_folder_structure.md ..."
)
return false_response

log.info(f"Valid image folder names found in: {folder_path}")
return true_response
log.info(f"...valid")
return True


def SaveConfigFile(
Expand Down Expand Up @@ -1231,7 +1230,9 @@ def save_to_file(content):
def check_duplicate_filenames(
folder_path, image_extension=[".gif", ".png", ".jpg", ".jpeg", ".webp"]
):
log.info("Checking for duplicate image filenames in training data directory...")
duplicate = False

log.info(f"Checking for duplicate image filenames in training data directory {folder_path}...")
for root, dirs, files in os.walk(folder_path):
filenames = {}
for file in files:
Expand All @@ -1241,15 +1242,138 @@ def check_duplicate_filenames(
if filename in filenames:
existing_path = filenames[filename]
if existing_path != full_path:
print(
f"Warning: Same filename '{filename}' with different image extension found. This will cause training issues. Rename one of the file."
log.warning(
f"...same filename '{filename}' with different image extension found. This will cause training issues. Rename one of the file."
)
print(f"Existing file: {existing_path}")
print(f"Current file: {full_path}")
log.warning(f" Existing file: {existing_path}")
log.warning(f" Current file: {full_path}")
duplicate = True
else:
filenames[filename] = full_path
if not duplicate:
log.info("...valid")


def validate_paths(headless:bool = False, **kwargs):
from .class_source_model import default_models

pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path")
train_data_dir = kwargs.get("train_data_dir")
reg_data_dir = kwargs.get("reg_data_dir")
output_dir = kwargs.get("output_dir")
logging_dir = kwargs.get("logging_dir")
lora_network_weights= kwargs.get("lora_network_weights")
finetune_image_folder = kwargs.get("finetune_image_folder")
resume = kwargs.get("resume")
vae = kwargs.get("vae")

if pretrained_model_name_or_path is not None:
log.info(f"Validating model file or folder path {pretrained_model_name_or_path} existence...")

# Check if it matches the Hugging Face model pattern
if re.match(r'^[\w-]+\/[\w-]+$', pretrained_model_name_or_path):
log.info("...huggingface.co model, skipping validation")
elif pretrained_model_name_or_path not in default_models:
# If not one of the default models, check if it's a valid local path
if not os.path.exists(pretrained_model_name_or_path):
log.error(f"...source model path '{pretrained_model_name_or_path}' is missing or does not exist")
return False
else:
log.info("...valid")
else:
log.info("...valid")

# Check if train_data_dir is valid
if train_data_dir != None:
log.info(f"Validating training data folder path {train_data_dir} existence...")
if not train_data_dir or not os.path.exists(train_data_dir):
log.error(f"Image folder path '{train_data_dir}' is missing or does not exist")
return False
else:
log.info("...valid")

# Check if there are files with the same filename but different image extension... warn the user if it is the case.
check_duplicate_filenames(train_data_dir)

if not verify_image_folder_pattern(folder_path=train_data_dir):
return False

if finetune_image_folder != None:
log.info(f"Validating finetuning image folder path {finetune_image_folder} existence...")
if not finetune_image_folder or not os.path.exists(finetune_image_folder):
log.error(f"Image folder path '{finetune_image_folder}' is missing or does not exist")
return False
else:
log.info("...valid")

if reg_data_dir != None:
if reg_data_dir != "":
log.info(f"Validating regularisation data folder path {reg_data_dir} existence...")
if not os.path.exists(reg_data_dir):
log.error("...regularisation folder does not exist")
return False

if not verify_image_folder_pattern(folder_path=reg_data_dir):
return False
log.info("...valid")
else:
log.info("Regularisation folder not specified, skipping validation")

if output_dir != None:
log.info(f"Validating output folder path {output_dir} existence...")
if output_dir == "" or not os.path.exists(output_dir):
log.error("...output folder path is missing or invalid")
return False
else:
log.info("...valid")

if logging_dir != None:
if logging_dir != "":
log.info(f"Validating logging folder path {logging_dir} existence...")
if not os.path.exists(logging_dir):
log.error("...logging folder path is missing or invalid")
return False
else:
log.info("...valid")
else:
log.info("Logging folder not specified, skipping validation")

if lora_network_weights != None:
if lora_network_weights != "":
log.info(f"Validating LoRA Network Weight file path {lora_network_weights} existence...")
if not os.path.exists(lora_network_weights):
log.error("...path is invalid")
return False
else:
log.info("...valid")
else:
log.info("LoRA Network Weight file not specified, skipping validation")

if resume != None:
if resume != "":
log.info(f"Validating model resume file path {resume} existence...")
if not os.path.exists(resume):
log.error("...path is invalid")
return False
else:
log.info("...valid")
else:
log.info("Model resume file not specified, skipping validation")

if vae != None:
if vae != "":
log.info(f"Validating VAE file path {vae} existence...")
if not os.path.exists(vae):
log.error("...vae path is invalid")
return False
else:
log.info("...valid")
else:
log.info("VAE file not specified, skipping validation")


return True

def is_file_writable(file_path):
if not os.path.exists(file_path):
# print(f"File '{file_path}' does not exist.")
Expand Down
49 changes: 11 additions & 38 deletions kohya_gui/dreambooth_gui.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
# v1: initial release
# v2: add open and save folder icons
# v3: Add new Utilities tab for Dreambooth folder preparation
# v3.1: Adding captionning of images to utilities

import gradio as gr
import json
import math
import os
import subprocess
import sys
import pathlib
from datetime import datetime
Expand All @@ -19,11 +13,10 @@
run_cmd_advanced_training,
update_my_data,
check_if_model_exist,
output_message,
verify_image_folder_pattern,
SaveConfigFile,
save_to_file,
scriptdir,
validate_paths,
)
from .class_configuration_file import ConfigurationFile
from .class_source_model import SourceModel
Expand Down Expand Up @@ -406,36 +399,16 @@ def train_model(

headless_bool = True if headless.get("label") == "True" else False

if pretrained_model_name_or_path == "":
output_message(
msg="Source model information is missing", headless=headless_bool
)
return

if train_data_dir == "":
output_message(msg="Image folder path is missing", headless=headless_bool)
return

if not os.path.exists(train_data_dir):
output_message(msg="Image folder does not exist", headless=headless_bool)
return

if not verify_image_folder_pattern(train_data_dir):
return

if reg_data_dir != "":
if not os.path.exists(reg_data_dir):
output_message(
msg="Regularisation folder does not exist",
headless=headless_bool,
)
return

if not verify_image_folder_pattern(reg_data_dir):
return

if output_dir == "":
output_message(msg="Output folder path is missing", headless=headless_bool)
if not validate_paths(
output_dir=output_dir,
pretrained_model_name_or_path=pretrained_model_name_or_path,
train_data_dir=train_data_dir,
reg_data_dir=reg_data_dir,
headless=headless_bool,
logging_dir=logging_dir,
resume=resume,
vae=vae,
):
return

if not print_only_bool and check_if_model_exist(
Expand Down
Loading

0 comments on commit 7895de6

Please sign in to comment.