Keep model_avg on cpu #348

Merged
merged 10 commits into from
May 7, 2022
4 changes: 0 additions & 4 deletions egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -867,10 +867,6 @@ def run(rank, world_size, args):
model = DDP(model, device_ids=[rank])
model.device = device

-    if rank == 0:
-        model_avg.to(device)
-        model_avg.device = device
-
Collaborator

Have you tested with this change?
I wonder whether you need to do .to(torch.device('cpu')) when updating this.

Collaborator Author

Yes, I have tested it. model_avg is on the CPU by default, since it is a deepcopy of model made before the line model.to(device).

Collaborator

OK.
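
The point made in this thread can be checked with a minimal sketch. It assumes a small nn.Linear as a hypothetical stand-in for the real model and falls back to the CPU when no GPU is present; it is an illustration of the deepcopy behavior, not the training script itself.

```python
import copy

import torch
import torch.nn as nn

# Hypothetical stand-in for the real transducer model (assumption).
model = nn.Linear(4, 2)

# The copy is taken while `model` is still on the CPU.
model_avg = copy.deepcopy(model)

# Use a GPU if one is available; otherwise stay on the CPU.
device = (
    torch.device("cuda", 0)
    if torch.cuda.is_available()
    else torch.device("cpu")
)

# nn.Module.to() moves the parameters of `model` in place.
# `model_avg` is a separate object and is not affected.
model.to(device)

model_device = next(model.parameters()).device
model_avg_device = next(model_avg.parameters()).device
print(f"model: {model_device}, model_avg: {model_avg_device}")
```

On a machine with a GPU, the two devices differ, which is why the averaging code has to cope with tensors on different devices.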

optimizer = Eve(model.parameters(), lr=params.initial_lr)

scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
4 changes: 3 additions & 1 deletion icefall/checkpoint.py
@@ -467,5 +467,7 @@ def average_state_dict(
     uniqued_names = list(uniqued.values())
     for k in uniqued_names:
         state_dict_1[k] *= weight_1
-        state_dict_1[k] += state_dict_2[k] * weight_2
+        state_dict_1[k] += (
+            state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
+        )
         state_dict_1[k] *= scaling_factor
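
A simplified sketch of the same update is shown below. The function name weighted_average_ and the toy tensors are hypothetical, and the sketch omits the de-duplication step that the real function performs through uniqued; it only illustrates moving the second operand to the device of the first before the in-place accumulation.

```python
from typing import Dict

import torch


def weighted_average_(
    state_dict_1: Dict[str, torch.Tensor],
    state_dict_2: Dict[str, torch.Tensor],
    weight_1: float,
    weight_2: float,
    scaling_factor: float = 1.0,
) -> None:
    """Hypothetical helper: update state_dict_1 in place to
    (state_dict_1 * weight_1 + state_dict_2 * weight_2) * scaling_factor.

    Assumes both dicts have the same keys and floating-point tensors.
    """
    for k in state_dict_1:
        state_dict_1[k] *= weight_1
        # Move the second operand to the device of the first so that the
        # in-place add works even when the dicts live on different devices.
        state_dict_1[k] += (
            state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
        )
        state_dict_1[k] *= scaling_factor


# The running average stays on the CPU; the current weights may be on a GPU.
cur_device = "cuda" if torch.cuda.is_available() else "cpu"
avg = {"w": torch.tensor([1.0, 2.0])}
cur = {"w": torch.tensor([3.0, 6.0], device=cur_device)}

weighted_average_(avg, cur, weight_1=0.5, weight_2=0.5)
print(avg["w"])
```

Because every operation on state_dict_1 is in place, the result keeps the device of state_dict_1, so the averaged model remains on the CPU.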