common lr logging, set default None to ddp_timeout
kohya-ss committed Nov 5, 2023
1 parent 96d877b commit 6231aa9
Showing 4 changed files with 46 additions and 44 deletions.
9 changes: 2 additions & 7 deletions fine_tune.py
@@ -408,13 +408,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

current_loss = loss.detach().item()  # this is a mean, so batch size should not matter
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)

loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
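The per-script "lr" / "lr/d*lr" bookkeeping above is replaced by one call into train_util.append_lr_to_logs (added in library/train_util.py below). As a rough sketch, assuming a run whose scheduler reports a U-Net group and one text-encoder group (values invented), the dict handed to accelerator.log now looks like this:

```python
# Hypothetical payload passed to accelerator.log after this change
# (numbers are made up; the key layout follows append_lr_to_logs below).
logs = {
    "loss": 0.1234,            # current_loss
    "lr/unet": 1.0e-4,         # lr_scheduler.get_last_lr()[0]
    "lr/text_encoder1": 5e-5,  # only if the scheduler reports a second param group
}
# For DAdapt*/Prodigy optimizers, matching "lr/d*lr/<name>" keys are appended as well.
```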
35 changes: 32 additions & 3 deletions library/train_util.py
@@ -2864,7 +2864,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
) # TODO move to SDXL training, because it is not supported by SD1/2
parser.add_argument(
"--ddp_timeout", type=int, default=30, help="DDP timeout (min) / DDPのタイムアウト(min)",
"--ddp_timeout",
type=int,
default=None,
help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
)
parser.add_argument(
"--clip_skip",
@@ -3806,12 +3809,15 @@ def prepare_accelerator(args: argparse.Namespace):
if args.wandb_api_key is not None:
wandb.login(key=args.wandb_api_key)

kwargs_handlers = (
None if args.ddp_timeout is None else [InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))]
)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=log_with,
project_dir=logging_dir,
kwargs_handlers=[InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))],
kwargs_handlers=kwargs_handlers,
)
return accelerator
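With the argparse default now None, --ddp_timeout only overrides the process-group timeout when it is given explicitly; otherwise kwargs_handlers stays None and Accelerate keeps its own default. A minimal standalone sketch of that conditional (build_kwargs_handlers is a hypothetical name, not part of the commit):

```python
import datetime

from accelerate import InitProcessGroupKwargs  # the handler train_util.py passes to Accelerator

def build_kwargs_handlers(ddp_timeout):
    """None -> no handler, so Accelerate/torch.distributed keep their default timeout.
    int  -> override the DDP process-group timeout, interpreted as minutes."""
    if ddp_timeout is None:
        return None
    return [InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=ddp_timeout))]

print(build_kwargs_handlers(None))  # None
print(build_kwargs_handlers(120))   # one-element list carrying a 2-hour timeout
```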

@@ -4401,6 +4407,29 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
return noise, noisy_latents, timesteps


def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
names = []
if including_unet:
names.append("unet")
names.append("text_encoder1")
names.append("text_encoder2")

append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)


def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names):
lrs = lr_scheduler.get_last_lr()

for lr_index in range(len(lrs)):
name = names[lr_index]
logs["lr/" + name] = float(lrs[lr_index])

if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
logs["lr/d*lr/" + name] = (
lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
)


# scheduler:
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
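A small usage sketch of the two helpers added above (append_lr_to_logs and append_lr_to_logs_with_names), run against stub objects so that no real training setup is needed. The stub classes, the numbers, and the import line are illustrative assumptions; the real call sites pass the accelerate-wrapped scheduler and args.optimizer_type:

```python
from library.train_util import append_lr_to_logs_with_names  # assumes the repo root as cwd

class StubOptimizer:
    # Prodigy/DAdaptation-style param groups: "d" * "lr" is the effective learning rate.
    param_groups = [{"d": 1.0, "lr": 1e-4}, {"d": 1.0, "lr": 5e-5}]

class StubScheduler:
    optimizers = [StubOptimizer()]

    def get_last_lr(self):
        return [1e-4, 5e-5]  # one entry per param group

logs = {"loss": 0.1}
append_lr_to_logs_with_names(logs, StubScheduler(), "Prodigy", ["unet", "text_encoder1"])
print(logs)
# {'loss': 0.1, 'lr/unet': 0.0001, 'lr/d*lr/unet': 0.0001,
#  'lr/text_encoder1': 5e-05, 'lr/d*lr/text_encoder1': 5e-05}
```

Since the loop only walks len(get_last_lr()) entries, passing more names than the scheduler has param groups (as append_lr_to_logs does with its fixed unet/text_encoder1/text_encoder2 list) is harmless: the extra names are simply never used.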
@@ -4718,7 +4747,7 @@ def __init__(self):
self.loss_list: List[float] = []
self.loss_total: float = 0.0

def add(self, *, epoch:int, step: int, loss: float) -> None:
def add(self, *, epoch: int, step: int, loss: float) -> None:
if epoch == 0:
self.loss_list.append(loss)
else:
37 changes: 10 additions & 27 deletions sdxl_train.py
@@ -74,33 +74,22 @@ def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List


def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type):
lrs = lr_scheduler.get_last_lr()

lr_index = 0
names = []
block_index = 0
while lr_index < len(lrs):
while block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR + 2:
if block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR:
name = f"block{block_index}"
if block_lrs[block_index] == 0:
block_index += 1
continue
names.append(f"block{block_index}")
elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR:
name = "text_encoder1"
names.append("text_encoder1")
elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1:
name = "text_encoder2"
else:
raise ValueError(f"unexpected block_index: {block_index}")
names.append("text_encoder2")

block_index += 1

logs["lr/" + name] = float(lrs[lr_index])

if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
logs["lr/d*lr/" + name] = (
lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
)

lr_index += 1
train_util.append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
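For SDXL block-wise learning rates, append_block_lr_to_logs now only assembles the list of names, skipping blocks whose configured LR is 0 (those presumably get no param group), and hands the actual logging to the shared train_util helper. A sketch of the name list it produces, with UNET_NUM_BLOCKS_FOR_BLOCK_LR assumed to be 23 for illustration and block_lr_names a hypothetical stand-in for the loop above:

```python
UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23  # assumed value, for this sketch only

def block_lr_names(block_lrs):
    names = []
    for block_index in range(UNET_NUM_BLOCKS_FOR_BLOCK_LR):
        if block_lrs[block_index] != 0:          # zero-LR blocks are not trained, so not logged
            names.append(f"block{block_index}")
    names += ["text_encoder1", "text_encoder2"]  # the text encoders always follow the blocks
    return names

block_lrs = [1e-4] * UNET_NUM_BLOCKS_FOR_BLOCK_LR
block_lrs[0] = 0  # e.g. freeze the first block
print(block_lr_names(block_lrs)[:3])  # ['block1', 'block2', 'block3']
```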


def train(args):
@@ -287,8 +276,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if args.gradient_checkpointing:
text_encoder1.gradient_checkpointing_enable()
text_encoder2.gradient_checkpointing_enable()
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
train_text_encoder1 = lr_te1 > 0
train_text_encoder2 = lr_te2 > 0

@@ -647,15 +636,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if args.logging_dir is not None:
logs = {"loss": current_loss}
if block_lrs is None:
logs["lr"] = float(lr_scheduler.get_last_lr()[0])
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
else:
append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type)
append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type) # U-Net is included in block_lrs

accelerator.log(logs, step=global_step)

9 changes: 2 additions & 7 deletions train_db.py
@@ -394,13 +394,8 @@ def train(args):

current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)

loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
