From c3bb03253fe0260a1f1abd8189e18a788862a54b Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 6 May 2022 14:45:00 +0800 Subject: [PATCH 1/8] keep model_avg on cpu --- egs/librispeech/ASR/pruned_transducer_stateless4/train.py | 4 ---- icefall/checkpoint.py | 4 +++- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 147bcf658f..cc61b3b32b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -867,10 +867,6 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank]) model.device = device - if rank == 0: - model_avg.to(device) - model_avg.device = device - optimizer = Eve(model.parameters(), lr=params.initial_lr) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 5b562ccc87..ba3823ffc3 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -467,5 +467,7 @@ def average_state_dict( uniqued_names = list(uniqued.values()) for k in uniqued_names: state_dict_1[k] *= weight_1 - state_dict_1[k] += state_dict_2[k] * weight_2 + state_dict_1[k] += ( + state_dict_2[k].to(device=state_dict_1[k].device) * weight_2 + ) state_dict_1[k] *= scaling_factor From ae50acad8bce1752311ac4908d1222dac5ec70e7 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 6 May 2022 15:35:09 +0800 Subject: [PATCH 2/8] explicitly convert model_avg to cpu --- egs/librispeech/ASR/pruned_transducer_stateless4/train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index cc61b3b32b..e3121935ad 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -867,6 +867,10 @@ def run(rank, 
world_size, args): model = DDP(model, device_ids=[rank]) model.device = device + if rank == 0: + model_avg = model_avg.to(torch.device("cpu")) + model_avg.device = device + optimizer = Eve(model.parameters(), lr=params.initial_lr) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) From 3c405fb6fdd461711181ce3dd5df71e6150e3583 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 6 May 2022 15:37:30 +0800 Subject: [PATCH 3/8] minor fix --- egs/librispeech/ASR/pruned_transducer_stateless4/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index e3121935ad..2f2fbd8380 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -868,7 +868,7 @@ def run(rank, world_size, args): model.device = device if rank == 0: - model_avg = model_avg.to(torch.device("cpu")) + model_avg.to(torch.device("cpu")) model_avg.device = device optimizer = Eve(model.parameters(), lr=params.initial_lr) From a72048be3e7b942946de5bc245d4e52043cb0a33 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 6 May 2022 15:58:52 +0800 Subject: [PATCH 4/8] remove device conversion for model_avg --- egs/librispeech/ASR/pruned_transducer_stateless4/train.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 2f2fbd8380..cc61b3b32b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -867,10 +867,6 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank]) model.device = device - if rank == 0: - model_avg.to(torch.device("cpu")) - model_avg.device = device - optimizer = Eve(model.parameters(), lr=params.initial_lr) scheduler = Eden(optimizer,
params.lr_batches, params.lr_epochs) From dd439b190675f0d1d87816c1635f70d69d997f6c Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 6 May 2022 22:03:49 +0800 Subject: [PATCH 5/8] modify usage of the model device in train.py --- .../ASR/pruned_transducer_stateless4/train.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index cc61b3b32b..accf0aa843 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -479,7 +479,7 @@ def load_checkpoint_if_available( def save_checkpoint( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, @@ -529,7 +529,7 @@ def save_checkpoint( def compute_loss( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, @@ -553,7 +553,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
""" - device = model.device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -609,7 +613,7 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, @@ -643,7 +647,7 @@ def compute_validation_loss( def train_one_epoch( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, sp: spm.SentencePieceProcessor, @@ -865,7 +869,6 @@ def run(rank, world_size, args): if world_size > 1: logging.info("Using DDP") model = DDP(model, device_ids=[rank]) - model.device = device optimizer = Eve(model.parameters(), lr=params.initial_lr) @@ -986,7 +989,7 @@ def remove_short_and_long_utt(c: Cut): def scan_pessimistic_batches_for_oom( - model: nn.Module, + model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, From b1e9d2186d6771b27220cfc947e5b01cab637498 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Fri, 6 May 2022 22:20:14 +0800 Subject: [PATCH 6/8] change model.device to next(model.parameters()).device for decoding --- .../ASR/pruned_transducer_stateless2/beam_search.py | 10 +++++----- .../ASR/pruned_transducer_stateless4/decode.py | 3 +-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index ad492aaa5a..fc1285dc72 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -276,7 +276,7 @@ def greedy_search( context_size = model.decoder.context_size unk_id = getattr(model, "unk_id", blank_id) - device 
= model.device + device = next(model.parameters()).device decoder_input = torch.tensor( [blank_id] * context_size, device=device, dtype=torch.int64 @@ -350,7 +350,7 @@ def greedy_search_batch( assert encoder_out.ndim == 3 assert encoder_out.size(0) >= 1, encoder_out.size(0) - device = model.device + device = next(model.parameters()).device batch_size = encoder_out.size(0) T = encoder_out.size(1) @@ -580,7 +580,7 @@ def modified_beam_search( blank_id = model.decoder.blank_id unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size - device = model.device + device = next(model.parameters()).device B = [HypothesisList() for _ in range(batch_size)] for i in range(batch_size): B[i].add( @@ -705,7 +705,7 @@ def _deprecated_modified_beam_search( unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size - device = model.device + device = next(model.parameters()).device T = encoder_out.size(1) @@ -813,7 +813,7 @@ def beam_search( unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size - device = model.device + device = next(model.parameters()).device decoder_input = torch.tensor( [blank_id] * context_size, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index e066629052..025ebd7bc6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -250,7 +250,7 @@ def decode_one_batch( Return the decoding result. See above description for the format of the returned dict. 
""" - device = model.device + device = next(model.parameters()).device feature = batch["inputs"] assert feature.ndim == 3 @@ -560,7 +560,6 @@ def main(): model.to(device) model.eval() - model.device = device if params.decoding_method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) From 0bee5d058a4b5420213d97600491661e519620de Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sat, 7 May 2022 10:32:57 +0800 Subject: [PATCH 7/8] assert params.start_epoch>0 --- egs/librispeech/ASR/pruned_transducer_stateless4/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index accf0aa843..2886b8dc76 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -125,8 +125,8 @@ def get_parser(): "--start-epoch", type=int, default=1, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from + help="""Resume training from this epoch. It should be positive. 
+ If larger than 1, it will load checkpoint from exp-dir/epoch-{start_epoch-1}.pt """, ) @@ -861,6 +861,7 @@ def run(rank, world_size, args): # model_avg is only used with rank 0 model_avg = copy.deepcopy(model) + assert params.start_epoch > 0 checkpoints = load_checkpoint_if_available( params=params, model=model, model_avg=model_avg ) From 537617302aa640af5aefdc7ba94df3cd32616615 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sat, 7 May 2022 10:36:04 +0800 Subject: [PATCH 8/8] assert params.start_epoch>0, params.start_epoch --- egs/librispeech/ASR/pruned_transducer_stateless4/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 2886b8dc76..4ff69d521b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -861,7 +861,7 @@ def run(rank, world_size, args): # model_avg is only used with rank 0 model_avg = copy.deepcopy(model) - assert params.start_epoch > 0 + assert params.start_epoch > 0, params.start_epoch checkpoints = load_checkpoint_if_available( params=params, model=model, model_avg=model_avg )