From ec29c8b865124f49e26971930923406e7fa68b91 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Wed, 24 Jan 2024 17:53:54 -0500 Subject: [PATCH] Make TDT inference not require duration params (#8207) * Make TDT inference not require duration params Signed-off-by: Hainan Xu * addressed review comments Signed-off-by: Hainan Xu * stylistic change Signed-off-by: Hainan Xu --------- Signed-off-by: Hainan Xu Co-authored-by: Hainan Xu --- .../collections/asr/models/rnnt_bpe_models.py | 7 +++++ nemo/collections/asr/models/rnnt_models.py | 3 +++ .../asr/parts/submodules/rnnt_decoding.py | 26 ++++++++++++------- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/nemo/collections/asr/models/rnnt_bpe_models.py b/nemo/collections/asr/models/rnnt_bpe_models.py index 9fbace676fb7f..53676fa6d2e5f 100644 --- a/nemo/collections/asr/models/rnnt_bpe_models.py +++ b/nemo/collections/asr/models/rnnt_bpe_models.py @@ -459,6 +459,13 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig): decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls)) decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) + loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get("loss", None)) + + if loss_name == 'tdt': + decoding_cfg.durations = loss_kwargs.durations + elif loss_name == 'multiblank_rnnt': + decoding_cfg.big_blank_durations = loss_kwargs.big_blank_durations + self.decoding = RNNTBPEDecoding( decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, ) diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index e297bc3f96ed0..0a4d94b690d00 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -78,6 +78,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): if loss_name == 'tdt': num_classes = num_classes - self.joint.num_extra_outputs + self.cfg.decoding.durations = loss_kwargs.durations + elif loss_name == 'multiblank_rnnt': + self.cfg.decoding.big_blank_durations = loss_kwargs.big_blank_durations self.loss = RNNTLoss( num_classes=num_classes, diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index 5c474ee21f8fd..e303f0d0295cd 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -208,15 +208,17 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): self.compute_timestamps = self.cfg.get('compute_timestamps', None) self.word_seperator = self.cfg.get('word_seperator', ' ') - if self.durations is not None: # this means it's a TDT model. + if self.durations is not None and self.durations != []: # this means it's a TDT model. if blank_id == 0: raise ValueError("blank_id must equal len(non_blank_vocabs) for TDT models") - if self.big_blank_durations is not None: + if self.big_blank_durations is not None and self.big_blank_durations != []: raise ValueError("duration and big_blank_durations can't both be not None") if self.cfg.strategy not in ['greedy', 'greedy_batch']: raise ValueError("currently only greedy and greedy_batch inference is supported for TDT models") - if self.big_blank_durations is not None: # this means it's a multi-blank model. + if ( + self.big_blank_durations is not None and self.big_blank_durations != [] + ): # this means it's a multi-blank model. if blank_id == 0: raise ValueError("blank_id must equal len(vocabs) for multi-blank RNN-T models") if self.cfg.strategy not in ['greedy', 'greedy_batch']: @@ -260,8 +262,8 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): raise NotImplementedError(f"Confidence calculation is not supported for strategy `{self.cfg.strategy}`") if self.cfg.strategy == 'greedy': - if self.big_blank_durations is None: - if self.durations is None: + if self.big_blank_durations is None or self.big_blank_durations == []: + if self.durations is None or self.durations == []: self.decoding = rnnt_greedy_decoding.GreedyRNNTInfer( decoder_model=decoder, joint_model=joint, @@ -303,8 +305,8 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): ) elif self.cfg.strategy == 'greedy_batch': - if self.big_blank_durations is None: - if self.durations is None: + if self.big_blank_durations is None or self.big_blank_durations == []: + if self.durations is None or self.durations == []: self.decoding = rnnt_greedy_decoding.GreedyBatchedRNNTInfer( decoder_model=decoder, joint_model=joint, @@ -522,10 +524,10 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp # RNN-T sample level is already preprocessed by implicit RNNT decoding # Simply remove any blank and possibly big blank tokens - if self.big_blank_durations is not None: # multi-blank RNNT + if self.big_blank_durations is not None and self.big_blank_durations != []: # multi-blank RNNT num_extra_outputs = len(self.big_blank_durations) prediction = [p for p in prediction if p < self.blank_id - num_extra_outputs] - elif self.durations is not None: # TDT model. + elif self.durations is not None and self.durations != []: # TDT model. prediction = [p for p in prediction if p < self.blank_id] else: # standard RNN-T prediction = [p for p in prediction if p != self.blank_id] @@ -1508,6 +1510,12 @@ class RNNTDecodingConfig: # can be used to change temperature for decoding temperature: float = 1.0 + # config for TDT decoding. + durations: Optional[List[int]] = field(default_factory=list) + + # config for multiblank decoding. + big_blank_durations: Optional[List[int]] = field(default_factory=list) + @dataclass class RNNTBPEDecodingConfig(RNNTDecodingConfig):