Skip to content

Commit

Permalink
Merge pull request #1200 from BootsofLagrangian/deep-speed
Browse files Browse the repository at this point in the history
Fix sdxl_train.py in deepspeed branch
  • Loading branch information
kohya-ss authored Mar 20, 2024
2 parents fbb98f1 + d945602 commit a35e7bd
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder1=text_encoder1 if train_text_encoder1 else None,
text_encoder2=text_encoder2 if train_text_encoder2 else None,
)
ds_model = accelerator.prepare(ds_model)
# most ZeRO stages use optimizer partitioning, so we have to prepare the optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]

else:
Expand All @@ -423,8 +426,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)

optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
Expand Down

0 comments on commit a35e7bd

Please sign in to comment.