From 4fcd49b843b5144b3162640a8d69119d44c9e206 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Wed, 1 May 2024 08:27:58 -0400 Subject: [PATCH] Improve files and folders validation --- README.md | 1 + kohya_gui/common_gui.py | 169 +++++++++++------------------ kohya_gui/dreambooth_gui.py | 64 ++++++++--- kohya_gui/finetune_gui.py | 51 +++++++-- kohya_gui/lora_gui.py | 93 ++++++++++------ kohya_gui/textual_inversion_gui.py | 61 ++++++++--- 6 files changed, 261 insertions(+), 178 deletions(-) diff --git a/README.md b/README.md index 39a61d13c..c6ef0f8ad 100644 --- a/README.md +++ b/README.md @@ -437,3 +437,4 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG - To ensure cross-platform compatibility and security, the GUI now defaults to using "shell=False" when running subprocesses. This is based on documentation and should not cause issues on most platforms. However, some users have reported issues on specific platforms such as runpod and colab. PLease open an issue if you encounter any issues. - Add support for custom LyCORIS toml config files. Simply type the path to the config file in the LyCORIS preset dropdown. +- Improve files and folders validation diff --git a/kohya_gui/common_gui.py b/kohya_gui/common_gui.py index 73a7f8a8d..43f247667 100644 --- a/kohya_gui/common_gui.py +++ b/kohya_gui/common_gui.py @@ -14,6 +14,7 @@ import json import math import shutil +import toml # Set up logging log = setup_logging() @@ -354,7 +355,7 @@ def update_my_data(my_data): except ValueError: # Handle the case where the string is not a valid float my_data[key] = int(1) - + for key in [ "max_train_steps", ]: @@ -419,7 +420,7 @@ def update_my_data(my_data): my_data["xformers"] = "xformers" else: my_data["xformers"] = "none" - + # Convert use_wandb to log_with="wandb" if it is set to True for key in ["use_wandb"]: value = my_data.get(key) @@ -430,17 +431,16 @@ def update_my_data(my_data): except ValueError: # Handle the case where the string is not a valid float pass - + my_data.pop(key, None) - - + # Replace the lora_network_weights key with network_weights keeping the original value for key in ["lora_network_weights"]: - value = my_data.get(key) # Get original value - if value is not None: # Check if the key exists in the dictionary + value = my_data.get(key) # Get original value + if value is not None: # Check if the key exists in the dictionary my_data["network_weights"] = value my_data.pop(key, None) - + return my_data @@ -740,7 +740,6 @@ def add_pre_postfix( postfix: str = "", caption_file_ext: str = ".caption", recursive: bool = False, - ) -> None: """ Add prefix and/or postfix to the content of caption files within a folder. @@ -1358,118 +1357,74 @@ def check_duplicate_filenames( log.info("...valid") -def validate_paths(headless: bool = False, **kwargs: Optional[str]) -> bool: - """ - Validates the existence of specified paths and patterns for model training configurations. - - This function checks for the existence of various directory paths and files provided as keyword arguments, - including model paths, data directories, output directories, and more. It leverages predefined default - models for validation and ensures directory creation if necessary. - - Args: - headless (bool): A flag indicating if the function should run without requiring user input. - **kwargs (Optional[str]): Keyword arguments that represent various path configurations, - including but not limited to `pretrained_model_name_or_path`, `train_data_dir`, - and more. +def validate_file_path(file_path: str) -> bool: + if file_path == "": + return True + msg = f"Validating {file_path} existence..." + if not os.path.isfile(file_path): + log.error(f"{msg} FAILED: does not exist") + return False + log.info(f"{msg} SUCCESS") + return True - Returns: - bool: True if all specified paths are valid or have been successfully created; False otherwise. - """ - def validate_path( - path: Optional[str], path_type: str, create_if_missing: bool = False - ) -> bool: - """ - Validates the existence of a path. If the path does not exist and `create_if_missing` is True, - attempts to create the directory. - - Args: - path (Optional[str]): The path to validate. - path_type (str): Description of the path type for logging purposes. - create_if_missing (bool): Whether to create the directory if it does not exist. - - Returns: - bool: True if the path is valid or has been created; False otherwise. - """ - if path: - log.info(f"Validating {path_type} path {path} existence...") - if os.path.exists(path): - log.info("...valid") - else: - if create_if_missing: - try: - os.makedirs(path, exist_ok=True) - log.info(f"...created folder at {path}") - return True - except Exception as e: - log.error(f"...failed to create {path_type} folder: {e}") - return False - else: - log.error( - f"...{path_type} path '{path}' is missing or does not exist" - ) - return False - else: - log.info(f"{path_type} not specified, skipping validation") +def validate_folder_path(folder_path: str, can_be_written_to: bool = False) -> bool: + if folder_path == "": return True - - # Validates the model name or path against default models or existence as a local path - if not validate_model_path(kwargs.get("pretrained_model_name_or_path")): + msg = f"Validating {folder_path} existence{' and writability' if can_be_written_to else ''}..." + if not os.path.isdir(folder_path): + log.error(f"{msg} FAILED: does not exist") return False + if can_be_written_to: + if not os.access(folder_path, os.W_OK): + log.error(f"{msg} FAILED: is not writable.") + return False + log.info(f"{msg} SUCCESS") + return True - # Validates the existence of specified directories or files, and creates them if necessary - for key, value in kwargs.items(): - if key in ["output_dir", "logging_dir"]: - if not validate_path(value, key, create_if_missing=True): - return False - elif key in ["vae"]: - # Check if it matches the Hugging Face model pattern - if re.match(r"^[\w-]+\/[\w-]+$", value): - log.info("Checking vae... huggingface.co model, skipping validation") - else: - if not validate_path(value, key): - return False - else: - if key not in ["pretrained_model_name_or_path"]: - if not validate_path(value, key): - return False - +def validate_toml_file(file_path: str) -> bool: + if file_path == "": + return True + msg = f"Validating toml {file_path} existence and validity..." + if not os.path.isfile(file_path): + log.error(f"{msg} FAILED: does not exist") + return False + + try: + toml.load(file_path) + except: + log.error(f"{msg} FAILED: is not a valid toml file.") + return False + log.info(f"{msg} SUCCESS") return True -def validate_model_path(pretrained_model_name_or_path: Optional[str]) -> bool: +def validate_model_path(pretrained_model_name_or_path: str) -> bool: """ Validates the pretrained model name or path against Hugging Face models or local paths. Args: - pretrained_model_name_or_path (Optional[str]): The pretrained model name or path to validate. + pretrained_model_name_or_path (str): The pretrained model name or path to validate. Returns: bool: True if the path is a valid Hugging Face model or exists locally; False otherwise. """ from .class_source_model import default_models - if pretrained_model_name_or_path: - log.info( - f"Validating model file or folder path {pretrained_model_name_or_path} existence..." - ) + msg = f"Validating {pretrained_model_name_or_path} existence..." - # Check if it matches the Hugging Face model pattern - if re.match(r"^[\w-]+\/[\w-]+$", pretrained_model_name_or_path): - log.info("...huggingface.co model, skipping validation") - elif pretrained_model_name_or_path not in default_models: - # If not one of the default models, check if it's a valid local path - if not os.path.exists(pretrained_model_name_or_path): - log.error( - f"...source model path '{pretrained_model_name_or_path}' is missing or does not exist" - ) - return False - else: - log.info("...valid") - else: - log.info("...valid") + # Check if it matches the Hugging Face model pattern + if re.match(r"^[\w-]+\/[\w-]+$", pretrained_model_name_or_path): + log.info(f"{msg} SKIPPING: huggingface.co model") + elif pretrained_model_name_or_path in default_models: + log.info(f"{msg} SUCCESS") else: - log.info("Model name or path not specified, skipping validation") + # If not one of the default models, check if it's a valid local path + if not os.path.exists(pretrained_model_name_or_path): + log.error(f"{msg} FAILED: is missing or does not exist") + return False + log.info(f"{msg} SUCCESS") + return True @@ -1497,6 +1452,7 @@ def is_file_writable(file_path: str) -> bool: # If an IOError occurs, the file cannot be written to return False + def print_command_and_toml(run_cmd, tmpfilename): log.warning( "Here is the trainer command as a reference. It will not be executed:\n" @@ -1514,16 +1470,19 @@ def print_command_and_toml(run_cmd, tmpfilename): log.info(f"end of toml config file: {tmpfilename}") save_to_file(command_to_run) - + + def validate_args_setting(input_string): # Regex pattern to handle multiple conditions: # - Empty string is valid # - Single or multiple key/value pairs with exactly one space between pairs # - No spaces around '=' and no spaces within keys or values - pattern = r'^(\S+=\S+)( \S+=\S+)*$|^$' + pattern = r"^(\S+=\S+)( \S+=\S+)*$|^$" if re.match(pattern, input_string): return True else: log.info(f"'{input_string}' is not a valid settings string.") - log.info("A valid settings string must consist of one or more key/value pairs formatted as key=value, with no spaces around the equals sign or within the value. Multiple pairs should be separated by a space.") - return False \ No newline at end of file + log.info( + "A valid settings string must consist of one or more key/value pairs formatted as key=value, with no spaces around the equals sign or within the value. Multiple pairs should be separated by a space." + ) + return False diff --git a/kohya_gui/dreambooth_gui.py b/kohya_gui/dreambooth_gui.py index 1bafece60..2dcaece8f 100644 --- a/kohya_gui/dreambooth_gui.py +++ b/kohya_gui/dreambooth_gui.py @@ -17,7 +17,7 @@ SaveConfigFile, scriptdir, update_my_data, - validate_paths, + validate_file_path, validate_folder_path, validate_model_path, validate_args_setting, ) from .class_accelerate_launch import AccelerateLaunch @@ -511,23 +511,57 @@ def train_model( log.info(f"Validating optimizer arguments...") if not validate_args_setting(optimizer_args): - return + return TRAIN_BUTTON_VISIBLE + + # + # Validate paths + # + + if not validate_file_path(dataset_config): + return TRAIN_BUTTON_VISIBLE + + if not validate_file_path(log_tracker_config): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(logging_dir, can_be_written_to=True): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(output_dir, can_be_written_to=True): + return TRAIN_BUTTON_VISIBLE + + if not validate_model_path(pretrained_model_name_or_path): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(reg_data_dir): + return TRAIN_BUTTON_VISIBLE + + if not validate_file_path(resume): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(train_data_dir): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(vae): + return TRAIN_BUTTON_VISIBLE + # + # End of path validation + # # This function validates files or folder paths. Simply add new variables containing file of folder path # to validate below - if not validate_paths( - output_dir=output_dir, - pretrained_model_name_or_path=pretrained_model_name_or_path, - train_data_dir=train_data_dir, - reg_data_dir=reg_data_dir, - headless=headless, - logging_dir=logging_dir, - log_tracker_config=log_tracker_config, - resume=resume, - vae=vae, - dataset_config=dataset_config, - ): - return TRAIN_BUTTON_VISIBLE + # if not validate_paths( + # dataset_config=dataset_config, + # headless=headless, + # log_tracker_config=log_tracker_config, + # logging_dir=logging_dir, + # output_dir=output_dir, + # pretrained_model_name_or_path=pretrained_model_name_or_path, + # reg_data_dir=reg_data_dir, + # resume=resume, + # train_data_dir=train_data_dir, + # vae=vae, + # ): + # return TRAIN_BUTTON_VISIBLE if not print_only and check_if_model_exist( output_name, output_dir, save_model_as, headless=headless diff --git a/kohya_gui/finetune_gui.py b/kohya_gui/finetune_gui.py index d3411f093..68c7ef055 100644 --- a/kohya_gui/finetune_gui.py +++ b/kohya_gui/finetune_gui.py @@ -18,7 +18,7 @@ SaveConfigFile, scriptdir, update_my_data, - validate_paths, + validate_file_path, validate_folder_path, validate_model_path, validate_args_setting, ) from .class_accelerate_launch import AccelerateLaunch @@ -556,17 +556,46 @@ def train_model( if train_dir != "" and not os.path.exists(train_dir): os.mkdir(train_dir) - if not validate_paths( - output_dir=output_dir, - pretrained_model_name_or_path=pretrained_model_name_or_path, - finetune_image_folder=image_folder, - headless=headless, - logging_dir=logging_dir, - log_tracker_config=log_tracker_config, - resume=resume, - dataset_config=dataset_config, - ): + # + # Validate paths + # + + if not validate_file_path(dataset_config): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(image_folder): + return TRAIN_BUTTON_VISIBLE + + if not validate_file_path(log_tracker_config): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(logging_dir, can_be_written_to=True): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(output_dir, can_be_written_to=True): + return TRAIN_BUTTON_VISIBLE + + if not validate_model_path(pretrained_model_name_or_path): + return TRAIN_BUTTON_VISIBLE + + if not validate_file_path(resume): return TRAIN_BUTTON_VISIBLE + + # + # End of path validation + # + + # if not validate_paths( + # dataset_config=dataset_config, + # finetune_image_folder=image_folder, + # headless=headless, + # log_tracker_config=log_tracker_config, + # logging_dir=logging_dir, + # output_dir=output_dir, + # pretrained_model_name_or_path=pretrained_model_name_or_path, + # resume=resume, + # ): + # return TRAIN_BUTTON_VISIBLE if not print_only and check_if_model_exist( output_name, output_dir, save_model_as, headless diff --git a/kohya_gui/lora_gui.py b/kohya_gui/lora_gui.py index b350b2606..a7b2816c2 100644 --- a/kohya_gui/lora_gui.py +++ b/kohya_gui/lora_gui.py @@ -19,7 +19,7 @@ SaveConfigFile, scriptdir, update_my_data, - validate_paths, + validate_file_path, validate_folder_path, validate_model_path, validate_toml_file, validate_args_setting, ) from .class_accelerate_launch import AccelerateLaunch @@ -682,23 +682,6 @@ def train_model( gr.Button(visible=False or headless), gr.Textbox(value=train_state_value), ] - - if LyCORIS_preset not in LYCORIS_PRESETS_CHOICES: - if not os.path.exists(LyCORIS_preset): - output_message( - msg=f"LyCORIS preset file {LyCORIS_preset} does not exist.", - headless=headless, - ) - return TRAIN_BUTTON_VISIBLE - else: - try: - toml.load(LyCORIS_preset) - except: - output_message( - msg=f"LyCORIS preset file {LyCORIS_preset} is not a valid toml file.", - headless=headless, - ) - return TRAIN_BUTTON_VISIBLE if executor.is_running(): log.error("Training is already running. Can't start another training session.") @@ -708,27 +691,69 @@ def train_model( log.info(f"Validating lr scheduler arguments...") if not validate_args_setting(lr_scheduler_args): - return + return TRAIN_BUTTON_VISIBLE log.info(f"Validating optimizer arguments...") if not validate_args_setting(optimizer_args): - return - - if not validate_paths( - output_dir=output_dir, - pretrained_model_name_or_path=pretrained_model_name_or_path, - train_data_dir=train_data_dir, - reg_data_dir=reg_data_dir, - headless=headless, - logging_dir=logging_dir, - log_tracker_config=log_tracker_config, - resume=resume, - vae=vae, - network_weights=network_weights, - dataset_config=dataset_config, - ): return TRAIN_BUTTON_VISIBLE + # + # Validate paths + # + + if not validate_file_path(dataset_config): + return TRAIN_BUTTON_VISIBLE + + if not validate_file_path(log_tracker_config): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(logging_dir, can_be_written_to=True): + return TRAIN_BUTTON_VISIBLE + + if LyCORIS_preset not in LYCORIS_PRESETS_CHOICES: + if not validate_toml_file(LyCORIS_preset): + return TRAIN_BUTTON_VISIBLE + + if not validate_file_path(network_weights): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(output_dir, can_be_written_to=True): + return TRAIN_BUTTON_VISIBLE + + if not validate_model_path(pretrained_model_name_or_path): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(reg_data_dir): + return TRAIN_BUTTON_VISIBLE + + if not validate_file_path(resume): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(train_data_dir): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(vae): + return TRAIN_BUTTON_VISIBLE + + # + # End of path validation + # + + # if not validate_paths( + # dataset_config=dataset_config, + # headless=headless, + # log_tracker_config=log_tracker_config, + # logging_dir=logging_dir, + # network_weights=network_weights, + # output_dir=output_dir, + # pretrained_model_name_or_path=pretrained_model_name_or_path, + # reg_data_dir=reg_data_dir, + # resume=resume, + # train_data_dir=train_data_dir, + # vae=vae, + # ): + # return TRAIN_BUTTON_VISIBLE + if int(bucket_reso_steps) < 1: output_message( msg="Bucket resolution steps need to be greater than 0", diff --git a/kohya_gui/textual_inversion_gui.py b/kohya_gui/textual_inversion_gui.py index af4ca3f66..720ce35ba 100644 --- a/kohya_gui/textual_inversion_gui.py +++ b/kohya_gui/textual_inversion_gui.py @@ -19,7 +19,7 @@ SaveConfigFile, scriptdir, update_my_data, - validate_paths, + validate_file_path, validate_folder_path, validate_model_path, validate_args_setting ) from .class_accelerate_launch import AccelerateLaunch @@ -514,19 +514,54 @@ def train_model( if not validate_args_setting(optimizer_args): return - if not validate_paths( - output_dir=output_dir, - pretrained_model_name_or_path=pretrained_model_name_or_path, - train_data_dir=train_data_dir, - reg_data_dir=reg_data_dir, - headless=headless, - logging_dir=logging_dir, - log_tracker_config=log_tracker_config, - resume=resume, - vae=vae, - dataset_config=dataset_config, - ): + # + # Validate paths + # + + if not validate_file_path(dataset_config): + return TRAIN_BUTTON_VISIBLE + + if not validate_file_path(log_tracker_config): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(logging_dir, can_be_written_to=True): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(output_dir, can_be_written_to=True): + return TRAIN_BUTTON_VISIBLE + + if not validate_model_path(pretrained_model_name_or_path): return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(reg_data_dir): + return TRAIN_BUTTON_VISIBLE + + if not validate_file_path(resume): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(train_data_dir): + return TRAIN_BUTTON_VISIBLE + + if not validate_folder_path(vae): + return TRAIN_BUTTON_VISIBLE + + # + # End of path validation + # + + # if not validate_paths( + # dataset_config=dataset_config, + # headless=headless, + # log_tracker_config=log_tracker_config, + # logging_dir=logging_dir, + # output_dir=output_dir, + # pretrained_model_name_or_path=pretrained_model_name_or_path, + # reg_data_dir=reg_data_dir, + # resume=resume, + # train_data_dir=train_data_dir, + # vae=vae, + # ): + # return TRAIN_BUTTON_VISIBLE if token_string == "": output_message(msg="Token string is missing", headless=headless)