From 5f348579d115e289422732b92fad3d5228deeb35 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sun, 9 Jul 2023 12:46:35 +0800
Subject: [PATCH 1/2] Update sdxl_train.py

---
 sdxl_train.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/sdxl_train.py b/sdxl_train.py
index 9cf20252d..06cbc5710 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -267,6 +267,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         unet.to(weight_dtype)
         text_encoder1.to(weight_dtype)
         text_encoder2.to(weight_dtype)
+    elif args.full_bf16:
+        assert (
+            args.mixed_precision == "bf16"
+        ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
+        accelerator.print("enable full bf16 training.")
+        unet.to(weight_dtype)
+        text_encoder1.to(weight_dtype)
+        text_encoder2.to(weight_dtype)
 
     # acceleratorがなんかよろしくやってくれるらしい
     if args.train_text_encoder:

From d974959738d89b6d62e7dc60c6e10ee0d8f8ff4c Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sun, 9 Jul 2023 12:47:26 +0800
Subject: [PATCH 2/2] Update train_util.py for full_bf16 support

---
 library/train_util.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/library/train_util.py b/library/train_util.py
index 62cd145e1..b9b4eecf8 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2416,6 +2416,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
     )
     parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
+    parser.add_argument("--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する")
     parser.add_argument(
         "--clip_skip",
         type=int,