Skip to content

Commit

Permalink
Revert some changes
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Sep 13, 2024
1 parent 46b57ed commit f744479
Showing 1 changed file with 97 additions and 49 deletions.
146 changes: 97 additions & 49 deletions kohya_gui/lora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,69 +934,117 @@ def train_model(
# unet_lr = 0

if dataset_config:
log.info("Dataset config TOML file used; skipping calculations for total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, and max_train_steps.")
log.info(
"Dataset config toml file used, skipping total_steps, train_batch_size, gradient_accumulation_steps, epoch, reg_factor, max_train_steps calculations..."
)
if max_train_steps > 0:
# calculate stop encoder training
if stop_text_encoder_training == 0:
stop_text_encoder_training = 0
else:
stop_text_encoder_training = math.ceil(
float(max_train_steps) / 100 * int(stop_text_encoder_training)
)

if lr_warmup != 0:
lr_warmup_steps = round(
float(int(lr_warmup) * int(max_train_steps) / 100)
)
else:
lr_warmup_steps = 0
else:
stop_text_encoder_training = 0
lr_warmup_steps = 0

if max_train_steps == 0:
max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
else:
max_train_steps_info = f"Max train steps: {max_train_steps}"

else:
if not train_data_dir:
log.error("Train data directory is empty.")
if train_data_dir == "":
log.error("Train data dir is empty")
return TRAIN_BUTTON_VISIBLE

# Get a list of all subfolders in train_data_dir
subfolders = [
f for f in os.listdir(train_data_dir)
f
for f in os.listdir(train_data_dir)
if os.path.isdir(os.path.join(train_data_dir, f))
]

total_steps = 0

# Loop through each subfolder and extract the number of repeats
for folder in subfolders:
try:
repeats_str = folder.split("_")[0]
repeats = int(repeats_str)
log.info(f"Folder '{folder}': {repeats} repeats found.")
except (ValueError, IndexError):
log.info(f"Skipping folder '{folder}': unable to extract repeat count.")
continue

folder_path = os.path.join(train_data_dir, folder)
image_extensions = (".jpg", ".jpeg", ".png", ".webp")
num_images = len([
file for file in os.listdir(folder_path)
if file.lower().endswith(image_extensions)
])
log.info(f"Folder '{folder}': {num_images} images found.")

steps = repeats * num_images
log.info(f"Folder '{folder}': {num_images} images * {repeats} repeats = {steps} steps.")
total_steps += steps

reg_factor = 2 if reg_data_dir else 1
if reg_factor == 2:
log.warning("Regularization images are used; the number of required steps will be doubled.")

log.info(f"Regularization factor: {reg_factor}")
# Extract the number of repeats from the folder name
repeats = int(folder.split("_")[0])
log.info(f"Folder {folder}: {repeats} repeats found")

# Count the number of images in the folder
num_images = len(
[
f
for f, lower_f in (
(file, file.lower())
for file in os.listdir(os.path.join(train_data_dir, folder))
)
if lower_f.endswith((".jpg", ".jpeg", ".png", ".webp"))
]
)

if max_train_steps == 0:
if train_batch_size == 0 or gradient_accumulation_steps == 0:
log.error("train_batch_size and gradient_accumulation_steps must be greater than zero.")
return TRAIN_BUTTON_VISIBLE

max_train_steps = int(math.ceil(
total_steps / train_batch_size / gradient_accumulation_steps * epoch * reg_factor
))
max_train_steps_info = (
f"Calculated max_train_steps: ({total_steps} / {train_batch_size} / "
f"{gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}"
)
log.info(f"Folder {folder}: {num_images} images found")

# Calculate the total number of steps for this folder
steps = repeats * num_images

# Log the computed step count for this folder
log.info(f"Folder {folder}: {num_images} * {repeats} = {steps} steps")

total_steps += steps

except ValueError:
# int() raises ValueError when the text before the first underscore is not an
# integer; this also covers folder names that contain no underscore at all
log.info(
f"Error: '{folder}' does not contain an underscore, skipping..."
)

if reg_data_dir == "":
reg_factor = 1
else:
max_train_steps_info = f"Max train steps: {max_train_steps}"
log.warning(
"Regularisation images are used... Will double the number of steps required..."
)
reg_factor = 2

log.info(f"Total steps: {total_steps}")
log.info(f"Regulatization factor: {reg_factor}")

# Calculate stop_text_encoder_training
if max_train_steps is not None and max_train_steps > 0 and stop_text_encoder_training > 0:
stop_text_encoder_training = math.ceil(
max_train_steps * stop_text_encoder_training / 100
)
else:
stop_text_encoder_training = 0
if max_train_steps == 0:
# calculate max_train_steps
max_train_steps = int(
math.ceil(
float(total_steps)
/ int(train_batch_size)
/ int(gradient_accumulation_steps)
* int(epoch)
* int(reg_factor)
)
)
max_train_steps_info = f"max_train_steps ({total_steps} / {train_batch_size} / {gradient_accumulation_steps} * {epoch} * {reg_factor}) = {max_train_steps}"
else:
if max_train_steps == 0:
max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
else:
max_train_steps_info = f"Max train steps: {max_train_steps}"

# calculate stop encoder training
if stop_text_encoder_training == 0:
stop_text_encoder_training = 0
else:
stop_text_encoder_training = math.ceil(
float(max_train_steps) / 100 * int(stop_text_encoder_training)
)

# Calculate lr_warmup_steps
if lr_warmup_steps > 0:
Expand Down

0 comments on commit f744479

Please sign in to comment.