Merge pull request #1064 from KohakuBlueleaf/fix-grad-sync
Avoid grad sync on each step even when doing accumulation
kohya-ss committed Jan 23, 2024
2 parents bea4362 + 711b40c commit 7a20df5
Showing 1 changed file with 5 additions and 4 deletions.
train_network.py: 9 changes (5 additions, 4 deletions)
@@ -842,10 +842,11 @@ def remove_model(old_ckpt_name):
 loss = loss.mean()  # it is already a mean, so no need to divide by batch_size

 accelerator.backward(loss)
-self.all_reduce_network(accelerator, network)  # sync DDP grad manually
-if accelerator.sync_gradients and args.max_grad_norm != 0.0:
-    params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
-    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+if accelerator.sync_gradients:
+    self.all_reduce_network(accelerator, network)  # sync DDP grad manually
+    if args.max_grad_norm != 0.0:
+        params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
+        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

 optimizer.step()
 lr_scheduler.step()
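For context on the change: Accelerate sets accelerator.sync_gradients to True only on the micro-step that completes a gradient-accumulation cycle, so gating the manual DDP all-reduce and the gradient clipping behind it makes them run once per optimizer update instead of on every micro-step. Below is a minimal, hypothetical sketch of that pattern (not the repository's code); the toy model, dataloader, and max_grad_norm value stand in for the trainer's LoRA network, its dataloader, and args.max_grad_norm.

import torch
from torch import nn
from torch.utils.data import DataLoader
from accelerate import Accelerator

# Toy stand-ins for the trainer's network, optimizer, and dataloader (illustrative only).
model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataloader = DataLoader([(torch.randn(8), torch.randn(1)) for _ in range(32)], batch_size=4)
max_grad_norm = 1.0  # stands in for args.max_grad_norm

accelerator = Accelerator(gradient_accumulation_steps=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for x, y in dataloader:
    with accelerator.accumulate(model):
        loss = nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)

        # sync_gradients is True only when this micro-step actually synchronizes
        # gradients (the end of an accumulation cycle), so clipping (and, in the
        # trainer, the manual DDP all-reduce) runs once per optimizer update.
        if accelerator.sync_gradients and max_grad_norm != 0.0:
            accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)

        # Inside accumulate(), Accelerate's wrapped optimizer skips the update on
        # non-sync micro-steps, so step()/zero_grad() can be called every iteration.
        optimizer.step()
        optimizer.zero_grad()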
