diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index ad492aaa5a..fc1285dc72 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -276,7 +276,7 @@ def greedy_search( context_size = model.decoder.context_size unk_id = getattr(model, "unk_id", blank_id) - device = model.device + device = next(model.parameters()).device decoder_input = torch.tensor( [blank_id] * context_size, device=device, dtype=torch.int64 @@ -350,7 +350,7 @@ def greedy_search_batch( assert encoder_out.ndim == 3 assert encoder_out.size(0) >= 1, encoder_out.size(0) - device = model.device + device = next(model.parameters()).device batch_size = encoder_out.size(0) T = encoder_out.size(1) @@ -580,7 +580,7 @@ def modified_beam_search( blank_id = model.decoder.blank_id unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size - device = model.device + device = next(model.parameters()).device B = [HypothesisList() for _ in range(batch_size)] for i in range(batch_size): B[i].add( @@ -705,7 +705,7 @@ def _deprecated_modified_beam_search( unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size - device = model.device + device = next(model.parameters()).device T = encoder_out.size(1) @@ -813,7 +813,7 @@ def beam_search( unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size - device = model.device + device = next(model.parameters()).device decoder_input = torch.tensor( [blank_id] * context_size, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index e066629052..025ebd7bc6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -250,7 +250,7 @@ def decode_one_batch( Return the decoding result. See above description for the format of the returned dict. """ - device = model.device + device = next(model.parameters()).device feature = batch["inputs"] assert feature.ndim == 3 @@ -560,7 +560,6 @@ def main(): model.to(device) model.eval() - model.device = device if params.decoding_method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 147bcf658f..4ff69d521b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -125,8 +125,8 @@ def get_parser(): "--start-epoch", type=int, default=1, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from exp-dir/epoch-{start_epoch-1}.pt """, ) @@ -479,7 +479,7 @@ def load_checkpoint_if_available( def save_checkpoint( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, @@ -529,7 +529,7 @@ def save_checkpoint( def compute_loss( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, @@ -553,7 +553,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -609,7 +613,7 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, @@ -643,7 +647,7 @@ def compute_validation_loss( def train_one_epoch( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, sp: spm.SentencePieceProcessor, @@ -857,6 +861,7 @@ def run(rank, world_size, args): # model_avg is only used with rank 0 model_avg = copy.deepcopy(model) + assert params.start_epoch > 0, params.start_epoch checkpoints = load_checkpoint_if_available( params=params, model=model, model_avg=model_avg ) @@ -865,11 +870,6 @@ def run(rank, world_size, args): if world_size > 1: logging.info("Using DDP") model = DDP(model, device_ids=[rank]) - model.device = device - - if rank == 0: - model_avg.to(device) - model_avg.device = device optimizer = Eve(model.parameters(), lr=params.initial_lr) @@ -990,7 +990,7 @@ def remove_short_and_long_utt(c: Cut): def scan_pessimistic_batches_for_oom( - model: nn.Module, + model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 5b562ccc87..ba3823ffc3 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -467,5 +467,7 @@ def average_state_dict( uniqued_names = list(uniqued.values()) for k in uniqued_names: state_dict_1[k] *= weight_1 - state_dict_1[k] += state_dict_2[k] * weight_2 + state_dict_1[k] += ( + state_dict_2[k].to(device=state_dict_1[k].device) * weight_2 + ) state_dict_1[k] *= scaling_factor