diff --git a/kohya_gui/lora_gui.py b/kohya_gui/lora_gui.py index be23f07f..29e34339 100644 --- a/kohya_gui/lora_gui.py +++ b/kohya_gui/lora_gui.py @@ -934,69 +934,117 @@ def train_model( # unet_lr = 0 if dataset_config: - log.info("Dataset config TOML file used; skipping calculations for total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, and max_train_steps.") + log.info( + "Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations..." + ) + if max_train_steps > 0: + # calculate stop encoder training + if stop_text_encoder_training == 0: + stop_text_encoder_training = 0 + else: + stop_text_encoder_training = math.ceil( + float(max_train_steps) / 100 * int(stop_text_encoder_training) + ) + + if lr_warmup != 0: + lr_warmup_steps = round( + float(int(lr_warmup) * int(max_train_steps) / 100) + ) + else: + lr_warmup_steps = 0 + else: + stop_text_encoder_training = 0 + lr_warmup_steps = 0 + + if max_train_steps == 0: + max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required." + else: + max_train_steps_info = f"Max train steps: {max_train_steps}" + else: - if not train_data_dir: - log.error("Train data directory is empty.") + if train_data_dir == "": + log.error("Train data dir is empty") return TRAIN_BUTTON_VISIBLE + # Get a list of all subfolders in train_data_dir subfolders = [ - f for f in os.listdir(train_data_dir) + f + for f in os.listdir(train_data_dir) if os.path.isdir(os.path.join(train_data_dir, f)) ] + total_steps = 0 + # Loop through each subfolder and extract the number of repeats for folder in subfolders: try: - repeats_str = folder.split("_")[0] - repeats = int(repeats_str) - log.info(f"Folder '{folder}': {repeats} repeats found.") - except (ValueError, IndexError): - log.info(f"Skipping folder '{folder}': unable to extract repeat count.") - continue - - folder_path = os.path.join(train_data_dir, folder) - image_extensions = (".jpg", ".jpeg", ".png", ".webp") - num_images = len([ - file for file in os.listdir(folder_path) - if file.lower().endswith(image_extensions) - ]) - log.info(f"Folder '{folder}': {num_images} images found.") - - steps = repeats * num_images - log.info(f"Folder '{folder}': {num_images} images * {repeats} repeats = {steps} steps.") - total_steps += steps - - reg_factor = 2 if reg_data_dir else 1 - if reg_factor == 2: - log.warning("Regularization images are used; the number of required steps will be doubled.") - - log.info(f"Regularization factor: {reg_factor}") + # Extract the number of repeats from the folder name + repeats = int(folder.split("_")[0]) + log.info(f"Folder {folder}: {repeats} repeats found") + + # Count the number of images in the folder + num_images = len( + [ + f + for f, lower_f in ( + (file, file.lower()) + for file in os.listdir(os.path.join(train_data_dir, folder)) + ) + if lower_f.endswith((".jpg", ".jpeg", ".png", ".webp")) + ] + ) - if max_train_steps == 0: - if train_batch_size == 0 or gradient_accumulation_steps == 0: - log.error("train_batch_size and gradient_accumulation_steps must be greater than zero.") - return TRAIN_BUTTON_VISIBLE - - max_train_steps = int(math.ceil( - total_steps / train_batch_size / gradient_accumulation_steps * epoch * reg_factor - )) - max_train_steps_info = ( - f"Calculated max_train_steps: ({total_steps} / {train_batch_size} / " - f"{gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}" - ) + log.info(f"Folder {folder}: {num_images} images found") + + # Calculate the total number of steps for this folder + steps = repeats * num_images + + # log.info the result + log.info(f"Folder {folder}: {num_images} * {repeats} = {steps} steps") + + total_steps += steps + + except ValueError: + # Handle the case where the folder name does not contain an underscore + log.info( + f"Error: '{folder}' does not contain an underscore, skipping..." + ) + + if reg_data_dir == "": + reg_factor = 1 else: - max_train_steps_info = f"Max train steps: {max_train_steps}" + log.warning( + "Regularisation images are used... Will double the number of steps required..." + ) + reg_factor = 2 - log.info(f"Total steps: {total_steps}") + log.info(f"Regulatization factor: {reg_factor}") - # Calculate stop_text_encoder_training - if max_train_steps is not None and max_train_steps > 0 and stop_text_encoder_training > 0: - stop_text_encoder_training = math.ceil( - max_train_steps * stop_text_encoder_training / 100 - ) - else: - stop_text_encoder_training = 0 + if max_train_steps == 0: + # calculate max_train_steps + max_train_steps = int( + math.ceil( + float(total_steps) + / int(train_batch_size) + / int(gradient_accumulation_steps) + * int(epoch) + * int(reg_factor) + ) + ) + max_train_steps_info = f"max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}" + else: + if max_train_steps == 0: + max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required." + else: + max_train_steps_info = f"Max train steps: {max_train_steps}" + + # calculate stop encoder training + if stop_text_encoder_training == 0: + stop_text_encoder_training = 0 + else: + stop_text_encoder_training = math.ceil( + float(max_train_steps) / 100 * int(stop_text_encoder_training) + ) # Calculate lr_warmup_steps if lr_warmup_steps > 0: