Skip to content

Commit

Permalink
explicitly convert model_avg to cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
yaozengwei committed May 6, 2022
1 parent c3bb032 commit ae50aca
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless4/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,10 @@ def run(rank, world_size, args):
model = DDP(model, device_ids=[rank])
model.device = device

if rank == 0:
model_avg = model_avg.to(torch.device("cpu"))
model_avg.device = device

optimizer = Eve(model.parameters(), lr=params.initial_lr)

scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
Expand Down

0 comments on commit ae50aca

Please sign in to comment.