diff --git a/sdxl_train.py b/sdxl_train.py index 613fe30b3..2cb80b6b3 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -412,7 +412,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): text_encoder1=text_encoder1 if train_text_encoder1 else None, text_encoder2=text_encoder2 if train_text_encoder2 else None, ) - ds_model = accelerator.prepare(ds_model) + # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) training_models = [ds_model] else: @@ -423,8 +426,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): text_encoder1 = accelerator.prepare(text_encoder1) if train_text_encoder2: text_encoder2 = accelerator.prepare(text_encoder2) - - optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: