From 33ccb844464708c03b64d2a745f0b443fff18e36 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sat, 9 Mar 2024 20:50:25 -0500 Subject: [PATCH 1/4] Add better python, model and data dir validation --- .release | 2 +- gui.bat | 9 +++++++-- kohya_gui/class_source_model.py | 25 ++++++++++++------------- kohya_gui/lora_gui.py | 22 +++++++++++----------- setup.bat | 2 +- setup/setup_common.py | 23 ++++++++++++++++++++++- setup/setup_linux.py | 1 + setup/setup_runpod.py | 1 + setup/setup_windows.py | 1 + setup/validate_requirements.py | 6 +++--- 10 files changed, 60 insertions(+), 32 deletions(-) diff --git a/.release b/.release index 0ba9aef02..0cd624f8f 100644 --- a/.release +++ b/.release @@ -1 +1 @@ -v23.0.1 +v23.0.1 \ No newline at end of file diff --git a/gui.bat b/gui.bat index 7b4154150..1afbaead5 100644 --- a/gui.bat +++ b/gui.bat @@ -1,10 +1,15 @@ @echo off +set PYTHON_VER=3.10.9 + :: Deactivate the virtual environment call .\venv\Scripts\deactivate.bat -:: Calling external python program to check for local modules -:: python .\setup\check_local_modules.py --no_question +:: Check if Python version meets the recommended version +python --version 2>nul | findstr /b /c:"Python %PYTHON_VER%" >nul +if errorlevel 1 ( + echo Warning: Python version %PYTHON_VER% is required. Kohya_ss GUI will most likely fail to run. +) :: Activate the virtual environment call .\venv\Scripts\activate.bat diff --git a/kohya_gui/class_source_model.py b/kohya_gui/class_source_model.py index 9017fcfa7..f82525f69 100644 --- a/kohya_gui/class_source_model.py +++ b/kohya_gui/class_source_model.py @@ -15,6 +15,18 @@ save_style_symbol = '\U0001f4be' # 💾 document_symbol = '\U0001F4C4' # 📄 +default_models = [ + 'stabilityai/stable-diffusion-xl-base-1.0', + 'stabilityai/stable-diffusion-xl-refiner-1.0', + 'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned', + 'stabilityai/stable-diffusion-2-1-base', + 'stabilityai/stable-diffusion-2-base', + 'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned', + 'stabilityai/stable-diffusion-2-1', + 'stabilityai/stable-diffusion-2', + 'runwayml/stable-diffusion-v1-5', + 'CompVis/stable-diffusion-v1-4', +] class SourceModel: def __init__( @@ -39,19 +51,6 @@ def __init__( self.save_model_as_choices = save_model_as_choices self.finetuning = finetuning - default_models = [ - 'stabilityai/stable-diffusion-xl-base-1.0', - 'stabilityai/stable-diffusion-xl-refiner-1.0', - 'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned', - 'stabilityai/stable-diffusion-2-1-base', - 'stabilityai/stable-diffusion-2-base', - 'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned', - 'stabilityai/stable-diffusion-2-1', - 'stabilityai/stable-diffusion-2', - 'runwayml/stable-diffusion-v1-5', - 'CompVis/stable-diffusion-v1-4', - ] - from .common_gui import create_refresh_button default_data_dir = default_data_dir if default_data_dir is not None else os.path.join(scriptdir, "outputs") diff --git a/kohya_gui/lora_gui.py b/kohya_gui/lora_gui.py index 16c3162d1..a262a3854 100644 --- a/kohya_gui/lora_gui.py +++ b/kohya_gui/lora_gui.py @@ -560,23 +560,23 @@ def train_model( log.info(f"Start training LoRA {LoRA_type} ...") headless_bool = True if headless.get("label") == "True" else False - if pretrained_model_name_or_path == "": - output_message( - msg="Source model information is missing", headless=headless_bool - ) - return + from .class_source_model import default_models - if train_data_dir == "": - output_message(msg="Image folder path is missing", headless=headless_bool) + # Check if the pretrained_model_name_or_path is valid + if pretrained_model_name_or_path not in default_models: + # If not one of the default models, check if it's a valid path + if not pretrained_model_name_or_path or not os.path.exists(pretrained_model_name_or_path): + log.error(f"Source model path '{pretrained_model_name_or_path}' is missing or does not exist") + return + + # Check if train_data_dir is valid + if not train_data_dir or not os.path.exists(train_data_dir): + log.error(f"Image folder path '{train_data_dir}' is missing or does not exist") return # Check if there are files with the same filename but different image extension... warn the user if it is the case. check_duplicate_filenames(train_data_dir) - if not os.path.exists(train_data_dir): - output_message(msg="Image folder does not exist", headless=headless_bool) - return - if not verify_image_folder_pattern(train_data_dir): return diff --git a/setup.bat b/setup.bat index d88cfaffc..7d1cc70f1 100644 --- a/setup.bat +++ b/setup.bat @@ -5,7 +5,7 @@ set PYTHON_VER=3.10.9 :: Check if Python version meets the recommended version python --version 2>nul | findstr /b /c:"Python %PYTHON_VER%" >nul if errorlevel 1 ( - echo Warning: Python version %PYTHON_VER% is recommended. + echo Warning: Python version %PYTHON_VER% is required. Kohya_ss GUI will most likely fail to run. ) IF NOT EXIST venv ( diff --git a/setup/setup_common.py b/setup/setup_common.py index edddec4ce..d4dbf8108 100644 --- a/setup/setup_common.py +++ b/setup/setup_common.py @@ -10,9 +10,30 @@ import platform import pkg_resources +from packaging import version + errors = 0 # Define the 'errors' variable before using it log = logging.getLogger('sd') +def check_python_version(): + """ + Check if the current Python version is >= 3.10.9 and < 3.11.0 + + Returns: + bool: True if the current Python version is valid, False otherwise. + """ + min_version = (3, 10, 9) + max_version = (3, 11, 0) + current_version = sys.version_info + + log.info(f"Python version is {sys.version}") + + if not (min_version <= current_version < max_version): + log.error(f"The current version of python is not appropriate to run Kohya_ss GUI") + log.error("The python version need to be greater or equal to 3.10.9 and less than 3.11.0") + + return (min_version <= current_version < max_version) + def update_submodule(): """ Ensure the submodule is initialized and updated. @@ -352,7 +373,7 @@ def check_repo_version(): # pylint: disable=unused-argument with open(os.path.join('./.release'), 'r', encoding='utf8') as file: release= file.read() - log.info(f'Version: {release}') + log.info(f'Kohya_ss GUI version: {release}') else: log.debug('Could not read release...') diff --git a/setup/setup_linux.py b/setup/setup_linux.py index 76ea9dacf..de627e408 100644 --- a/setup/setup_linux.py +++ b/setup/setup_linux.py @@ -25,6 +25,7 @@ def main_menu(platform_requirements_file, show_stdout: bool = False, no_run_acce if __name__ == '__main__': + python_ver = setup_common.check_python_version() setup_common.ensure_base_requirements() setup_common.setup_logging() diff --git a/setup/setup_runpod.py b/setup/setup_runpod.py index d909191d4..f33c570ce 100644 --- a/setup/setup_runpod.py +++ b/setup/setup_runpod.py @@ -59,6 +59,7 @@ def main_menu(platform_requirements_file): if __name__ == '__main__': + python_ver = setup_common.check_python_version() setup_common.ensure_base_requirements() setup_common.setup_logging() diff --git a/setup/setup_windows.py b/setup/setup_windows.py index 9b6b6fe6d..bf1dbbc34 100644 --- a/setup/setup_windows.py +++ b/setup/setup_windows.py @@ -223,6 +223,7 @@ def main_menu(): if __name__ == "__main__": + python_ver = setup_common.check_python_version() setup_common.ensure_base_requirements() setup_common.setup_logging() main_menu() diff --git a/setup/validate_requirements.py b/setup/validate_requirements.py index df88689f2..ee924a757 100644 --- a/setup/validate_requirements.py +++ b/setup/validate_requirements.py @@ -1,5 +1,4 @@ import os -import re import sys import shutil import argparse @@ -88,8 +87,7 @@ def check_torch(): except Exception as e: log.error(f'Could not load torch: {e}') sys.exit(1) - - + def main(): setup_common.check_repo_version() # Parse command line arguments @@ -107,6 +105,8 @@ def main(): torch_ver = check_torch() + python_ver = setup_common.check_python_version() + setup_common.update_submodule() if args.requirements: From 40702a6866c37509ccda852e44a0961e05cd29b0 Mon Sep 17 00:00:00 2001 From: Ashley Kleynhans Date: Sun, 10 Mar 2024 19:44:59 +0200 Subject: [PATCH 2/4] Bump RunPod requirements to match the Linux ones --- requirements_runpod.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements_runpod.txt b/requirements_runpod.txt index 9b9703696..b301aca3c 100644 --- a/requirements_runpod.txt +++ b/requirements_runpod.txt @@ -1,5 +1,5 @@ -torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 # no_verify leave this to specify not checking this a verification stage -xformers==0.0.21 bitsandbytes==0.41.1 -tensorboard==2.14.1 tensorflow==2.14.0 wheel +torch==2.1.2+cu118 torchvision==0.16.2+cu118 xformers==0.0.23.post1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 # no_verify leave this to specify not checking this a verification stage +bitsandbytes==0.41.2 +tensorboard==2.15.2 tensorflow==2.15.0.post1 wheel tensorrt -r requirements.txt From 40314706c4e49f1d3ef741eeba4a4b279d398d15 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sun, 10 Mar 2024 15:27:33 -0400 Subject: [PATCH 3/4] Improve user file/folder input validation before running training scripts --- .release | 2 +- README.md | 9 +- kohya_gui/common_gui.py | 152 ++++++++++++++++++++++++++--- kohya_gui/dreambooth_gui.py | 49 +++------- kohya_gui/finetune_gui.py | 53 +++------- kohya_gui/lora_gui.py | 66 +++---------- kohya_gui/textual_inversion_gui.py | 70 +++---------- 7 files changed, 192 insertions(+), 209 deletions(-) diff --git a/.release b/.release index 0cd624f8f..56a37ffda 100644 --- a/.release +++ b/.release @@ -1 +1 @@ -v23.0.1 \ No newline at end of file +v23.0.2 \ No newline at end of file diff --git a/README.md b/README.md index 1a4205650..08d1ed99b 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,8 @@ The GUI allows you to set the training parameters and generate and run the requi - [No module called tkinter](#no-module-called-tkinter) - [SDXL training](#sdxl-training) - [Change History](#change-history) - - [2024/03/10 (v23.0.1)](#20240310-v2301) + - [2024/03/10 (v23.0.2)](#20240310-v2302) + - [2024/03/09 (v23.0.1)](#20240309-v2301) - [2024/03/02 (v23.0.0)](#20240302-v2300) ## 🦒 Colab @@ -364,7 +365,11 @@ The documentation in this section will be moved to a separate document later. ## Change History -### 2024/03/10 (v23.0.1) +### 2024/03/10 (v23.0.2) + +- Improve validation of path provided by users before running training + +### 2024/03/09 (v23.0.1) - Update bitsandbytes module to 0.43.0 as it provide native windows support - Minor fixes to code diff --git a/kohya_gui/common_gui.py b/kohya_gui/common_gui.py index 003c3aa7e..b8d510bdc 100644 --- a/kohya_gui/common_gui.py +++ b/kohya_gui/common_gui.py @@ -9,7 +9,6 @@ import json from .custom_logging import setup_logging -from datetime import datetime # Set up logging log = setup_logging() @@ -699,7 +698,7 @@ def run_cmd_advanced_training(**kwargs): if "additional_parameters" in kwargs: run_cmd += f' {kwargs["additional_parameters"]}' - if "block_lr" in kwargs: + if "block_lr" in kwargs and kwargs["block_lr"] != "": run_cmd += f' --block_lr="{kwargs["block_lr"]}"' if kwargs.get("bucket_no_upscale"): @@ -1143,12 +1142,12 @@ def run_cmd_advanced_training(**kwargs): def verify_image_folder_pattern(folder_path): false_response = True # temporarily set to true to prevent stopping training in case of false positive - true_response = True + log.info(f"Verifying image folder pattern of {folder_path}...") # Check if the folder exists if not os.path.isdir(folder_path): log.error( - f"The provided path '{folder_path}' is not a valid folder. Please follow the folder structure documentation found at docs\image_folder_structure.md ..." + f"...the provided path '{folder_path}' is not a valid folder. Please follow the folder structure documentation found at docs\image_folder_structure.md ..." ) return false_response @@ -1176,22 +1175,22 @@ def verify_image_folder_pattern(folder_path): non_matching_subfolders = set(subfolders) - set(matching_subfolders) if non_matching_subfolders: log.error( - f"The following folders do not match the required pattern _: {', '.join(non_matching_subfolders)}" + f"...the following folders do not match the required pattern _: {', '.join(non_matching_subfolders)}" ) log.error( - f"Please follow the folder structure documentation found at docs\image_folder_structure.md ..." + f"...please follow the folder structure documentation found at docs\image_folder_structure.md ..." ) return false_response # Check if no sub-folders exist if not matching_subfolders: log.error( - f"No image folders found in {folder_path}. Please follow the folder structure documentation found at docs\image_folder_structure.md ..." + f"...no image folders found in {folder_path}. Please follow the folder structure documentation found at docs\image_folder_structure.md ..." ) return false_response - log.info(f"Valid image folder names found in: {folder_path}") - return true_response + log.info(f"...valid") + return True def SaveConfigFile( @@ -1231,7 +1230,9 @@ def save_to_file(content): def check_duplicate_filenames( folder_path, image_extension=[".gif", ".png", ".jpg", ".jpeg", ".webp"] ): - log.info("Checking for duplicate image filenames in training data directory...") + duplicate = False + + log.info(f"Checking for duplicate image filenames in training data directory {folder_path}...") for root, dirs, files in os.walk(folder_path): filenames = {} for file in files: @@ -1241,15 +1242,138 @@ def check_duplicate_filenames( if filename in filenames: existing_path = filenames[filename] if existing_path != full_path: - print( - f"Warning: Same filename '{filename}' with different image extension found. This will cause training issues. Rename one of the file." + log.warning( + f"...same filename '{filename}' with different image extension found. This will cause training issues. Rename one of the file." ) - print(f"Existing file: {existing_path}") - print(f"Current file: {full_path}") + log.warning(f" Existing file: {existing_path}") + log.warning(f" Current file: {full_path}") + duplicate = True else: filenames[filename] = full_path + if not duplicate: + log.info("...valid") +def validate_paths(headless:bool = False, **kwargs): + from .class_source_model import default_models + + pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path") + train_data_dir = kwargs.get("train_data_dir") + reg_data_dir = kwargs.get("reg_data_dir") + output_dir = kwargs.get("output_dir") + logging_dir = kwargs.get("logging_dir") + lora_network_weights= kwargs.get("lora_network_weights") + finetune_image_folder = kwargs.get("finetune_image_folder") + resume = kwargs.get("resume") + vae = kwargs.get("vae") + + if pretrained_model_name_or_path is not None: + log.info(f"Validating model file or folder path {pretrained_model_name_or_path} existence...") + + # Check if it matches the Hugging Face model pattern + if re.match(r'^[\w-]+\/[\w-]+$', pretrained_model_name_or_path): + log.info("...huggingface.co model, skipping validation") + elif pretrained_model_name_or_path not in default_models: + # If not one of the default models, check if it's a valid local path + if not os.path.exists(pretrained_model_name_or_path): + log.error(f"...source model path '{pretrained_model_name_or_path}' is missing or does not exist") + return False + else: + log.info("...valid") + else: + log.info("...valid") + + # Check if train_data_dir is valid + if train_data_dir != None: + log.info(f"Validating training data folder path {train_data_dir} existance...") + if not train_data_dir or not os.path.exists(train_data_dir): + log.error(f"Image folder path '{train_data_dir}' is missing or does not exist") + return False + else: + log.info("...valid") + + # Check if there are files with the same filename but different image extension... warn the user if it is the case. + check_duplicate_filenames(train_data_dir) + + if not verify_image_folder_pattern(folder_path=train_data_dir): + return False + + if finetune_image_folder != None: + log.info(f"Validating finetuning image folder path {finetune_image_folder} existance...") + if not finetune_image_folder or not os.path.exists(finetune_image_folder): + log.error(f"Image folder path '{finetune_image_folder}' is missing or does not exist") + return False + else: + log.info("...valid") + + if reg_data_dir != None: + if reg_data_dir != "": + log.info(f"Validating regularisation data folder path {reg_data_dir} existance...") + if not os.path.exists(reg_data_dir): + log.error("...regularisation folder does not exist") + return False + + if not verify_image_folder_pattern(folder_path=reg_data_dir): + return False + log.info("...valid") + else: + log.info("Regularisation folder not specified, skipping validation") + + if output_dir != None: + log.info(f"Validating output folder path {output_dir} existance...") + if output_dir == "" or not os.path.exists(output_dir): + log.error("...output folder path is missing or invalid") + return False + else: + log.info("...valid") + + if logging_dir != None: + if logging_dir != "": + log.info(f"Validating logging folder path {logging_dir} existance...") + if not os.path.exists(logging_dir): + log.error("...logging folder path is missing or invalid") + return False + else: + log.info("...valid") + else: + log.info("Logging folder not specified, skipping validation") + + if lora_network_weights != None: + if lora_network_weights != "": + log.info(f"Validating LoRA Network Weight file path {lora_network_weights} existance...") + if not os.path.exists(lora_network_weights): + log.error("...path is invalid") + return False + else: + log.info("...valid") + else: + log.info("LoRA Network Weight file not specified, skipping validation") + + if resume != None: + if resume != "": + log.info(f"Validating model resume file path {resume} existance...") + if not os.path.exists(resume): + log.error("...path is invalid") + return False + else: + log.info("...valid") + else: + log.info("Model resume file not specified, skipping validation") + + if vae != None: + if vae != "": + log.info(f"Validating VAE file path {vae} existance...") + if not os.path.exists(vae): + log.error("...vae path is invalid") + return False + else: + log.info("...valid") + else: + log.info("VAE file not specified, skipping validation") + + + return True + def is_file_writable(file_path): if not os.path.exists(file_path): # print(f"File '{file_path}' does not exist.") diff --git a/kohya_gui/dreambooth_gui.py b/kohya_gui/dreambooth_gui.py index 981a8bfb1..be382c912 100644 --- a/kohya_gui/dreambooth_gui.py +++ b/kohya_gui/dreambooth_gui.py @@ -1,13 +1,7 @@ -# v1: initial release -# v2: add open and save folder icons -# v3: Add new Utilities tab for Dreambooth folder preparation -# v3.1: Adding captionning of images to utilities - import gradio as gr import json import math import os -import subprocess import sys import pathlib from datetime import datetime @@ -19,11 +13,10 @@ run_cmd_advanced_training, update_my_data, check_if_model_exist, - output_message, - verify_image_folder_pattern, SaveConfigFile, save_to_file, scriptdir, + validate_paths, ) from .class_configuration_file import ConfigurationFile from .class_source_model import SourceModel @@ -406,36 +399,16 @@ def train_model( headless_bool = True if headless.get("label") == "True" else False - if pretrained_model_name_or_path == "": - output_message( - msg="Source model information is missing", headless=headless_bool - ) - return - - if train_data_dir == "": - output_message(msg="Image folder path is missing", headless=headless_bool) - return - - if not os.path.exists(train_data_dir): - output_message(msg="Image folder does not exist", headless=headless_bool) - return - - if not verify_image_folder_pattern(train_data_dir): - return - - if reg_data_dir != "": - if not os.path.exists(reg_data_dir): - output_message( - msg="Regularisation folder does not exist", - headless=headless_bool, - ) - return - - if not verify_image_folder_pattern(reg_data_dir): - return - - if output_dir == "": - output_message(msg="Output folder path is missing", headless=headless_bool) + if not validate_paths( + output_dir=output_dir, + pretrained_model_name_or_path=pretrained_model_name_or_path, + train_data_dir=train_data_dir, + reg_data_dir=reg_data_dir, + headless=headless_bool, + logging_dir=logging_dir, + resume=resume, + vae=vae, + ): return if not print_only_bool and check_if_model_exist( diff --git a/kohya_gui/finetune_gui.py b/kohya_gui/finetune_gui.py index 3cb2470c6..a1fa805e8 100644 --- a/kohya_gui/finetune_gui.py +++ b/kohya_gui/finetune_gui.py @@ -7,7 +7,6 @@ import pathlib from datetime import datetime from .common_gui import ( - get_folder_path, get_file_path, get_saveasfile_path, save_inference_file, @@ -15,10 +14,10 @@ color_aug_changed, update_my_data, check_if_model_exist, - output_message, SaveConfigFile, save_to_file, scriptdir, + validate_paths, ) from .class_configuration_file import ConfigurationFile from .class_source_model import SourceModel @@ -436,39 +435,24 @@ def train_model( headless_bool = True if headless.get("label") == "True" else False - if output_dir == "": - output_message(msg="Output folder path is missing", headless=headless_bool) + if train_dir != "" and not os.path.exists(train_dir): + os.mkdir(train_dir) + + if not validate_paths( + output_dir=output_dir, + pretrained_model_name_or_path=pretrained_model_name_or_path, + finetune_image_folder=image_folder, + headless=headless_bool, + logging_dir=logging_dir, + resume=resume, + ): return - if train_dir is None or train_dir.strip() == "": - train_dir = output_dir - if not print_only_bool and check_if_model_exist(output_name, output_dir, save_model_as, headless_bool): return - # if float(noise_offset) > 0 and ( - # multires_noise_iterations > 0 or multires_noise_discount > 0 - # ): - # output_message( - # msg="noise offset and multires_noise can't be set at the same time. Only use one or the other.", - # title='Error', - # headless=headless_bool, - # ) - # return - - # if optimizer == 'Adafactor' and lr_warmup != '0': - # output_message( - # msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", - # title='Warning', - # headless=headless_bool, - # ) - # lr_warmup = '0' - # create caption json file if generate_caption_database: - if train_dir != "" and not os.path.exists(train_dir): - os.mkdir(train_dir) - run_cmd = fr'{PYTHON} "{scriptdir}/sd-scripts/finetune/merge_captions_to_metadata.py"' if caption_extension == "": run_cmd += f' --caption_extension=".caption"' @@ -488,19 +472,6 @@ def train_model( # Run the command subprocess.run(run_cmd, shell=True, env=env) - # check pretrained_model_name_or_path - - if not os.path.exists(pretrained_model_name_or_path): - try: - from modules import sd_models - - info = sd_models.get_closet_checkpoint_match(pretrained_model_name_or_path) - if info is not None: - pretrained_model_name_or_path = info.filename - - except Exception: - pass - # create images buckets if generate_image_buckets: run_cmd = fr'{PYTHON} "{scriptdir}/sd-scripts/finetune/prepare_buckets_latents.py"' diff --git a/kohya_gui/lora_gui.py b/kohya_gui/lora_gui.py index a262a3854..25fbdf613 100644 --- a/kohya_gui/lora_gui.py +++ b/kohya_gui/lora_gui.py @@ -2,7 +2,6 @@ import json import math import os -import lycoris from datetime import datetime from .common_gui import ( get_file_path, @@ -13,11 +12,10 @@ update_my_data, check_if_model_exist, output_message, - verify_image_folder_pattern, SaveConfigFile, save_to_file, - check_duplicate_filenames, scriptdir, + validate_paths, ) from .class_configuration_file import ConfigurationFile from .class_source_model import SourceModel @@ -560,39 +558,17 @@ def train_model( log.info(f"Start training LoRA {LoRA_type} ...") headless_bool = True if headless.get("label") == "True" else False - from .class_source_model import default_models - - # Check if the pretrained_model_name_or_path is valid - if pretrained_model_name_or_path not in default_models: - # If not one of the default models, check if it's a valid path - if not pretrained_model_name_or_path or not os.path.exists(pretrained_model_name_or_path): - log.error(f"Source model path '{pretrained_model_name_or_path}' is missing or does not exist") - return - - # Check if train_data_dir is valid - if not train_data_dir or not os.path.exists(train_data_dir): - log.error(f"Image folder path '{train_data_dir}' is missing or does not exist") - return - - # Check if there are files with the same filename but different image extension... warn the user if it is the case. - check_duplicate_filenames(train_data_dir) - - if not verify_image_folder_pattern(train_data_dir): - return - - if reg_data_dir != "": - if not os.path.exists(reg_data_dir): - output_message( - msg="Regularisation folder does not exist", - headless=headless_bool, - ) - return - - if not verify_image_folder_pattern(reg_data_dir): - return - - if output_dir == "": - output_message(msg="Output folder path is missing", headless=headless_bool) + if not validate_paths( + output_dir=output_dir, + pretrained_model_name_or_path=pretrained_model_name_or_path, + train_data_dir=train_data_dir, + reg_data_dir=reg_data_dir, + headless=headless_bool, + logging_dir=logging_dir, + resume=resume, + vae=vae, + lora_network_weights=lora_network_weights, + ): return if int(bucket_reso_steps) < 1: @@ -612,16 +588,6 @@ def train_model( ) return - # if float(noise_offset) > 0 and ( - # multires_noise_iterations > 0 or multires_noise_discount > 0 - # ): - # output_message( - # msg="noise offset and multires_noise can't be set at the same time. Only use one or the other.", - # title='Error', - # headless=headless_bool, - # ) - # return - if not os.path.exists(output_dir): os.makedirs(output_dir) @@ -637,14 +603,6 @@ def train_model( ): return - # if optimizer == 'Adafactor' and lr_warmup != '0': - # output_message( - # msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", - # title='Warning', - # headless=headless_bool, - # ) - # lr_warmup = '0' - # If string is empty set string to 0. if text_encoder_lr == "": text_encoder_lr = 0 diff --git a/kohya_gui/textual_inversion_gui.py b/kohya_gui/textual_inversion_gui.py index 50674603f..39391a9a5 100644 --- a/kohya_gui/textual_inversion_gui.py +++ b/kohya_gui/textual_inversion_gui.py @@ -1,13 +1,7 @@ -# v1: initial release -# v2: add open and save folder icons -# v3: Add new Utilities tab for Dreambooth folder preparation -# v3.1: Adding captionning of images to utilities - import gradio as gr import json import math import os -import subprocess import pathlib from datetime import datetime from .common_gui import ( @@ -19,12 +13,12 @@ update_my_data, check_if_model_exist, output_message, - verify_image_folder_pattern, SaveConfigFile, save_to_file, scriptdir, list_files, create_refresh_button, + validate_paths, ) from .class_configuration_file import ConfigurationFile from .class_source_model import SourceModel @@ -411,36 +405,16 @@ def train_model( headless_bool = True if headless.get("label") == "True" else False - if pretrained_model_name_or_path == "": - output_message( - msg="Source model information is missing", headless=headless_bool - ) - return - - if train_data_dir == "": - output_message(msg="Image folder path is missing", headless=headless_bool) - return - - if not os.path.exists(train_data_dir): - output_message(msg="Image folder does not exist", headless=headless_bool) - return - - if not verify_image_folder_pattern(train_data_dir): - return - - if reg_data_dir != "": - if not os.path.exists(reg_data_dir): - output_message( - msg="Regularisation folder does not exist", - headless=headless_bool, - ) - return - - if not verify_image_folder_pattern(reg_data_dir): - return - - if output_dir == "": - output_message(msg="Output folder path is missing", headless=headless_bool) + if not validate_paths( + output_dir=output_dir, + pretrained_model_name_or_path=pretrained_model_name_or_path, + train_data_dir=train_data_dir, + reg_data_dir=reg_data_dir, + headless=headless_bool, + logging_dir=logging_dir, + resume=resume, + vae=vae, + ): return if token_string == "": @@ -451,30 +425,9 @@ def train_model( output_message(msg="Init word is missing", headless=headless_bool) return - if not os.path.exists(output_dir): - os.makedirs(output_dir) - if not print_only_bool and check_if_model_exist(output_name, output_dir, save_model_as, headless_bool): return - # if float(noise_offset) > 0 and ( - # multires_noise_iterations > 0 or multires_noise_discount > 0 - # ): - # output_message( - # msg="noise offset and multires_noise can't be set at the same time. Only use one or the other.", - # title='Error', - # headless=headless_bool, - # ) - # return - - # if optimizer == 'Adafactor' and lr_warmup != '0': - # output_message( - # msg="Warning: lr_scheduler is set to 'Adafactor', so 'LR warmup (% of steps)' will be considered 0.", - # title='Warning', - # headless=headless_bool, - # ) - # lr_warmup = '0' - # Get a list of all subfolders in train_data_dir subfolders = [ f @@ -570,7 +523,6 @@ def train_model( cache_latents=cache_latents, cache_latents_to_disk=cache_latents_to_disk, caption_dropout_every_n_epochs=caption_dropout_every_n_epochs, - caption_dropout_rate=caption_dropout_rate, caption_extension=caption_extension, clip_skip=clip_skip, color_aug=color_aug, From 07e7d3acf88364a74e2ff013bbd96151c79b6db4 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Sun, 10 Mar 2024 15:34:29 -0400 Subject: [PATCH 4/4] Fix typos --- kohya_gui/common_gui.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/kohya_gui/common_gui.py b/kohya_gui/common_gui.py index b8d510bdc..12f2d8347 100644 --- a/kohya_gui/common_gui.py +++ b/kohya_gui/common_gui.py @@ -1285,7 +1285,7 @@ def validate_paths(headless:bool = False, **kwargs): # Check if train_data_dir is valid if train_data_dir != None: - log.info(f"Validating training data folder path {train_data_dir} existance...") + log.info(f"Validating training data folder path {train_data_dir} existence...") if not train_data_dir or not os.path.exists(train_data_dir): log.error(f"Image folder path '{train_data_dir}' is missing or does not exist") return False @@ -1299,7 +1299,7 @@ def validate_paths(headless:bool = False, **kwargs): return False if finetune_image_folder != None: - log.info(f"Validating finetuning image folder path {finetune_image_folder} existance...") + log.info(f"Validating finetuning image folder path {finetune_image_folder} existence...") if not finetune_image_folder or not os.path.exists(finetune_image_folder): log.error(f"Image folder path '{finetune_image_folder}' is missing or does not exist") return False @@ -1308,7 +1308,7 @@ def validate_paths(headless:bool = False, **kwargs): if reg_data_dir != None: if reg_data_dir != "": - log.info(f"Validating regularisation data folder path {reg_data_dir} existance...") + log.info(f"Validating regularisation data folder path {reg_data_dir} existence...") if not os.path.exists(reg_data_dir): log.error("...regularisation folder does not exist") return False @@ -1320,7 +1320,7 @@ def validate_paths(headless:bool = False, **kwargs): log.info("Regularisation folder not specified, skipping validation") if output_dir != None: - log.info(f"Validating output folder path {output_dir} existance...") + log.info(f"Validating output folder path {output_dir} existence...") if output_dir == "" or not os.path.exists(output_dir): log.error("...output folder path is missing or invalid") return False @@ -1329,7 +1329,7 @@ def validate_paths(headless:bool = False, **kwargs): if logging_dir != None: if logging_dir != "": - log.info(f"Validating logging folder path {logging_dir} existance...") + log.info(f"Validating logging folder path {logging_dir} existence...") if not os.path.exists(logging_dir): log.error("...logging folder path is missing or invalid") return False @@ -1340,7 +1340,7 @@ def validate_paths(headless:bool = False, **kwargs): if lora_network_weights != None: if lora_network_weights != "": - log.info(f"Validating LoRA Network Weight file path {lora_network_weights} existance...") + log.info(f"Validating LoRA Network Weight file path {lora_network_weights} existence...") if not os.path.exists(lora_network_weights): log.error("...path is invalid") return False @@ -1351,7 +1351,7 @@ def validate_paths(headless:bool = False, **kwargs): if resume != None: if resume != "": - log.info(f"Validating model resume file path {resume} existance...") + log.info(f"Validating model resume file path {resume} existence...") if not os.path.exists(resume): log.error("...path is invalid") return False @@ -1362,7 +1362,7 @@ def validate_paths(headless:bool = False, **kwargs): if vae != None: if vae != "": - log.info(f"Validating VAE file path {vae} existance...") + log.info(f"Validating VAE file path {vae} existence...") if not os.path.exists(vae): log.error("...vae path is invalid") return False