Skip to content

Commit

Permalink
Improve files and folders validation (#2429)
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed May 1, 2024
1 parent d32771b commit 36071cc
Show file tree
Hide file tree
Showing 6 changed files with 261 additions and 178 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -437,3 +437,4 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG
- To ensure cross-platform compatibility and security, the GUI now defaults to using "shell=False" when running subprocesses. This is based on documentation and should not cause issues on most platforms. However, some users have reported issues on specific platforms such as runpod and colab. Please open an issue if you encounter any issues.
- Add support for custom LyCORIS toml config files. Simply type the path to the config file in the LyCORIS preset dropdown.
- Improve files and folders validation
169 changes: 64 additions & 105 deletions kohya_gui/common_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import json
import math
import shutil
import toml

# Set up logging
log = setup_logging()
Expand Down Expand Up @@ -354,7 +355,7 @@ def update_my_data(my_data):
except ValueError:
# Handle the case where the string is not a valid float
my_data[key] = int(1)

for key in [
"max_train_steps",
]:
Expand Down Expand Up @@ -419,7 +420,7 @@ def update_my_data(my_data):
my_data["xformers"] = "xformers"
else:
my_data["xformers"] = "none"

# Convert use_wandb to log_with="wandb" if it is set to True
for key in ["use_wandb"]:
value = my_data.get(key)
Expand All @@ -430,17 +431,16 @@ def update_my_data(my_data):
except ValueError:
# Handle the case where the string is not a valid float
pass

my_data.pop(key, None)



# Replace the lora_network_weights key with network_weights keeping the original value
for key in ["lora_network_weights"]:
value = my_data.get(key) # Get original value
if value is not None: # Check if the key exists in the dictionary
value = my_data.get(key) # Get original value
if value is not None: # Check if the key exists in the dictionary
my_data["network_weights"] = value
my_data.pop(key, None)

return my_data


Expand Down Expand Up @@ -740,7 +740,6 @@ def add_pre_postfix(
postfix: str = "",
caption_file_ext: str = ".caption",
recursive: bool = False,

) -> None:
"""
Add prefix and/or postfix to the content of caption files within a folder.
Expand Down Expand Up @@ -1358,118 +1357,74 @@ def check_duplicate_filenames(
log.info("...valid")


def validate_paths(headless: bool = False, **kwargs: Optional[str]) -> bool:
"""
Validates the existence of specified paths and patterns for model training configurations.
This function checks for the existence of various directory paths and files provided as keyword arguments,
including model paths, data directories, output directories, and more. It leverages predefined default
models for validation and ensures directory creation if necessary.
Args:
headless (bool): A flag indicating if the function should run without requiring user input.
**kwargs (Optional[str]): Keyword arguments that represent various path configurations,
including but not limited to `pretrained_model_name_or_path`, `train_data_dir`,
and more.
def validate_file_path(file_path: str) -> bool:
    """
    Check that an optional file path points to an existing file.

    Args:
        file_path (str): Path to validate. An empty string (or None) means
            "not specified" and is accepted without touching the filesystem.

    Returns:
        bool: True if the path is unspecified or names an existing file;
            False otherwise (an error is logged).
    """
    # `not file_path` covers None as well as "": an unset optional field must
    # not reach os.path.isfile(), which raises TypeError on None.
    if not file_path:
        return True

    msg = f"Validating {file_path} existence..."
    if not os.path.isfile(file_path):
        log.error(f"{msg} FAILED: does not exist")
        return False

    log.info(f"{msg} SUCCESS")
    return True

Returns:
bool: True if all specified paths are valid or have been successfully created; False otherwise.
"""

def validate_path(
path: Optional[str], path_type: str, create_if_missing: bool = False
) -> bool:
"""
Validates the existence of a path. If the path does not exist and `create_if_missing` is True,
attempts to create the directory.
Args:
path (Optional[str]): The path to validate.
path_type (str): Description of the path type for logging purposes.
create_if_missing (bool): Whether to create the directory if it does not exist.
Returns:
bool: True if the path is valid or has been created; False otherwise.
"""
if path:
log.info(f"Validating {path_type} path {path} existence...")
if os.path.exists(path):
log.info("...valid")
else:
if create_if_missing:
try:
os.makedirs(path, exist_ok=True)
log.info(f"...created folder at {path}")
return True
except Exception as e:
log.error(f"...failed to create {path_type} folder: {e}")
return False
else:
log.error(
f"...{path_type} path '{path}' is missing or does not exist"
)
return False
else:
log.info(f"{path_type} not specified, skipping validation")
def validate_folder_path(folder_path: str, can_be_written_to: bool = False) -> bool:
    """
    Check that an optional folder path points to an existing directory.

    Args:
        folder_path (str): Path to validate. An empty string (or None) means
            "not specified" and is accepted without touching the filesystem.
        can_be_written_to (bool): When True, the directory must also be
            reported as writable by os.access().

    Returns:
        bool: True if the path is unspecified, or is an existing directory
            that satisfies the writability requirement; False otherwise
            (an error is logged).
    """
    # `not folder_path` covers None as well as "": an unset optional field
    # must not reach os.path.isdir(), which raises TypeError on None.
    if not folder_path:
        return True

    msg = f"Validating {folder_path} existence{' and writability' if can_be_written_to else ''}..."
    if not os.path.isdir(folder_path):
        log.error(f"{msg} FAILED: does not exist")
        return False

    if can_be_written_to and not os.access(folder_path, os.W_OK):
        log.error(f"{msg} FAILED: is not writable.")
        return False

    log.info(f"{msg} SUCCESS")
    return True

# Validates the existence of specified directories or files, and creates them if necessary
for key, value in kwargs.items():
if key in ["output_dir", "logging_dir"]:
if not validate_path(value, key, create_if_missing=True):
return False
elif key in ["vae"]:
# Check if it matches the Hugging Face model pattern
if re.match(r"^[\w-]+\/[\w-]+$", value):
log.info("Checking vae... huggingface.co model, skipping validation")
else:
if not validate_path(value, key):
return False
else:
if key not in ["pretrained_model_name_or_path"]:
if not validate_path(value, key):
return False

def validate_toml_file(file_path: str) -> bool:
    """
    Check that an optional path points to an existing, parseable toml file.

    Args:
        file_path (str): Path to validate. An empty string (or None) means
            "not specified" and is accepted without touching the filesystem.

    Returns:
        bool: True if the path is unspecified, or the file exists and
            toml.load() reads it without raising; False otherwise
            (an error is logged).
    """
    # `not file_path` covers None as well as "".
    if not file_path:
        return True

    msg = f"Validating toml {file_path} existence and validity..."
    if not os.path.isfile(file_path):
        log.error(f"{msg} FAILED: does not exist")
        return False

    try:
        toml.load(file_path)
    except Exception as e:
        # Catch Exception rather than using a bare `except:` so that
        # KeyboardInterrupt / SystemExit still propagate, and report the
        # underlying reason to help the user fix the file.
        log.error(f"{msg} FAILED: is not a valid toml file. {e}")
        return False

    log.info(f"{msg} SUCCESS")
    return True


def validate_model_path(pretrained_model_name_or_path: str) -> bool:
    """
    Validates the pretrained model name or path against Hugging Face models or local paths.

    Args:
        pretrained_model_name_or_path (str): The pretrained model name or path to validate.

    Returns:
        bool: True if the value looks like a Hugging Face "owner/name" id, is
            listed in default_models, or exists locally; False otherwise
            (an error is logged).
    """
    from .class_source_model import default_models

    # An unset value (None or "") can never be a usable model. Fail with a
    # clear message instead of letting None raise TypeError in re.match().
    if not pretrained_model_name_or_path:
        log.error("Validating model name or path existence... FAILED: not specified")
        return False

    msg = f"Validating {pretrained_model_name_or_path} existence..."

    # Check if it matches the Hugging Face model pattern
    # NOTE(review): a relative local path made of exactly two [\w-] segments
    # (e.g. "models/foo") also matches and is therefore not checked on disk.
    if re.match(r"^[\w-]+\/[\w-]+$", pretrained_model_name_or_path):
        log.info(f"{msg} SKIPPING: huggingface.co model")
    elif pretrained_model_name_or_path in default_models:
        log.info(f"{msg} SUCCESS")
    else:
        # If not one of the default models, check if it's a valid local path
        if not os.path.exists(pretrained_model_name_or_path):
            log.error(f"{msg} FAILED: is missing or does not exist")
            return False
        log.info(f"{msg} SUCCESS")

    return True


Expand Down Expand Up @@ -1497,6 +1452,7 @@ def is_file_writable(file_path: str) -> bool:
# If an IOError occurs, the file cannot be written to
return False


def print_command_and_toml(run_cmd, tmpfilename):
log.warning(
"Here is the trainer command as a reference. It will not be executed:\n"
Expand All @@ -1514,16 +1470,19 @@ def print_command_and_toml(run_cmd, tmpfilename):
log.info(f"end of toml config file: {tmpfilename}")

save_to_file(command_to_run)



def validate_args_setting(input_string):
    """
    Check that a settings string is a space-separated list of key=value pairs.

    Args:
        input_string (str): The settings string to validate.

    Returns:
        bool: True if the string is empty or consists of one or more
            key=value pairs separated by exactly one space; False otherwise
            (an explanatory message is logged).
    """
    # Regex pattern to handle multiple conditions:
    # - Empty string is valid
    # - Single or multiple key/value pairs with exactly one space between pairs
    # - No spaces around '=' and no spaces within keys or values
    pattern = r"^(\S+=\S+)( \S+=\S+)*$|^$"

    # fullmatch() rather than match(): with match(), `$` also matches just
    # before a trailing newline, so "key=value\n" would be accepted.
    if re.fullmatch(pattern, input_string):
        return True

    log.info(f"'{input_string}' is not a valid settings string.")
    log.info(
        "A valid settings string must consist of one or more key/value pairs formatted as key=value, with no spaces around the equals sign or within the value. Multiple pairs should be separated by a space."
    )
    return False
64 changes: 49 additions & 15 deletions kohya_gui/dreambooth_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
SaveConfigFile,
scriptdir,
update_my_data,
validate_paths,
validate_file_path, validate_folder_path, validate_model_path,
validate_args_setting,
)
from .class_accelerate_launch import AccelerateLaunch
Expand Down Expand Up @@ -511,23 +511,57 @@ def train_model(

log.info(f"Validating optimizer arguments...")
if not validate_args_setting(optimizer_args):
return
return TRAIN_BUTTON_VISIBLE

#
# Validate paths
#

if not validate_file_path(dataset_config):
return TRAIN_BUTTON_VISIBLE

if not validate_file_path(log_tracker_config):
return TRAIN_BUTTON_VISIBLE

if not validate_folder_path(logging_dir, can_be_written_to=True):
return TRAIN_BUTTON_VISIBLE

if not validate_folder_path(output_dir, can_be_written_to=True):
return TRAIN_BUTTON_VISIBLE

if not validate_model_path(pretrained_model_name_or_path):
return TRAIN_BUTTON_VISIBLE

if not validate_folder_path(reg_data_dir):
return TRAIN_BUTTON_VISIBLE

if not validate_file_path(resume):
return TRAIN_BUTTON_VISIBLE

if not validate_folder_path(train_data_dir):
return TRAIN_BUTTON_VISIBLE

if not validate_folder_path(vae):
return TRAIN_BUTTON_VISIBLE
#
# End of path validation
#

# This function validates files or folder paths. Simply add new variables containing file or folder path
# to validate below
if not validate_paths(
output_dir=output_dir,
pretrained_model_name_or_path=pretrained_model_name_or_path,
train_data_dir=train_data_dir,
reg_data_dir=reg_data_dir,
headless=headless,
logging_dir=logging_dir,
log_tracker_config=log_tracker_config,
resume=resume,
vae=vae,
dataset_config=dataset_config,
):
return TRAIN_BUTTON_VISIBLE
# if not validate_paths(
# dataset_config=dataset_config,
# headless=headless,
# log_tracker_config=log_tracker_config,
# logging_dir=logging_dir,
# output_dir=output_dir,
# pretrained_model_name_or_path=pretrained_model_name_or_path,
# reg_data_dir=reg_data_dir,
# resume=resume,
# train_data_dir=train_data_dir,
# vae=vae,
# ):
# return TRAIN_BUTTON_VISIBLE

if not print_only and check_if_model_exist(
output_name, output_dir, save_model_as, headless=headless
Expand Down
Loading

0 comments on commit 36071cc

Please sign in to comment.