Accept None as an argument to decoder_lengths in GreedyBatchedCTCInfer::forward (#9246)

* Accept None as an argument to decoder_lengths in GreedyBatchedCTCInfer::forward

GreedyCTCInfer::forward already allowed this, so the two classes did not
implement the same interface; now they do. A usage sketch of the new
interface follows the commit trailers below.

Also warn when the decoder_lengths argument is not passed in explicitly,
since omitting it is likely an error on the user's part.

Signed-off-by: Daniel Galvez <[email protected]>

* Apply isort and black reformatting

Signed-off-by: titu1994 <[email protected]>

* Log warning only once for sanity.

Signed-off-by: Daniel Galvez <[email protected]>

---------

Signed-off-by: Daniel Galvez <[email protected]>
Signed-off-by: titu1994 <[email protected]>
Co-authored-by: titu1994 <[email protected]>
Co-authored-by: Somshubra Majumdar <[email protected]>
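
A minimal usage sketch of the interface after this change. The constructor argument, tensor shapes, and blank index below are illustrative assumptions rather than code taken from this commit:

```python
import torch

from nemo.collections.asr.parts.submodules.ctc_greedy_decoding import GreedyBatchedCTCInfer

B, T, C = 4, 10, 29  # assumed batch size, time steps, and classes including blank
decoder_output = torch.randn(B, T, C).log_softmax(dim=-1)  # batched CTC log-probs

# Treating the last class index as the blank is an assumption for this sketch.
decoding = GreedyBatchedCTCInfer(blank_id=C - 1)

# Explicit per-utterance lengths (the recommended call):
lengths = torch.full((B,), T, dtype=torch.long)
(hyps_explicit,) = decoding.forward(decoder_output=decoder_output, decoder_lengths=lengths)

# With this change, decoder_lengths=None is accepted as well; a warning is
# logged once and lengths default to the full time dimension of decoder_output.
(hyps_default,) = decoding.forward(decoder_output=decoder_output, decoder_lengths=None)
```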
3 people authored May 22, 2024
1 parent 26f566b commit 212023c
Showing 2 changed files with 32 additions and 9 deletions.
23 changes: 18 additions & 5 deletions nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -58,6 +58,9 @@ def _states_to_device(dec_state, device='cpu'):
return dec_state


_DECODER_LENGTHS_NONE_WARNING = "Passing in decoder_lengths=None for CTC decoding is likely to be an error, since it is unlikely that each element of your batch has exactly the same length. decoder_lengths will default to decoder_output.shape[0]."


class GreedyCTCInfer(Typing, ConfidenceMethodMixin):
"""A greedy CTC decoder.
@@ -148,7 +151,7 @@ def __init__(
def forward(
self,
decoder_output: torch.Tensor,
decoder_lengths: torch.Tensor,
decoder_lengths: Optional[torch.Tensor],
):
"""Returns a list of hypotheses given an input batch of the encoder hidden embedding.
Output token is generated auto-regressively.
@@ -167,6 +170,9 @@ def forward(
mode=logging_mode.ONCE,
)

if decoder_lengths is None:
logging.warning(_DECODER_LENGTHS_NONE_WARNING, mode=logging_mode.ONCE)

with torch.inference_mode():
hypotheses = []
# Process each sequence independently
@@ -213,7 +219,7 @@ def forward(
return (packed_result,)

@torch.no_grad()
def _greedy_decode_logprobs(self, x: torch.Tensor, out_len: torch.Tensor):
def _greedy_decode_logprobs(self, x: torch.Tensor, out_len: Optional[torch.Tensor]):
# x: [T, D]
# out_len: [seq_len]

@@ -243,7 +249,7 @@ def _greedy_decode_logprobs(self, x: torch.Tensor, out_len: torch.Tensor):
return hypothesis

@torch.no_grad()
def _greedy_decode_labels(self, x: torch.Tensor, out_len: torch.Tensor):
def _greedy_decode_labels(self, x: torch.Tensor, out_len: Optional[torch.Tensor]):
# x: [T]
# out_len: [seq_len]

@@ -370,7 +376,7 @@ def __init__(
def forward(
self,
decoder_output: torch.Tensor,
decoder_lengths: torch.Tensor,
decoder_lengths: Optional[torch.Tensor],
):
"""Returns a list of hypotheses given an input batch of the encoder hidden embedding.
Output token is generated auto-regressively.
@@ -383,11 +389,18 @@ def forward(
Returns:
packed list containing batch number of sentences (Hypotheses).
"""

input_decoder_lengths = decoder_lengths

if decoder_lengths is None:
logging.warning(_DECODER_LENGTHS_NONE_WARNING, mode=logging_mode.ONCE)
decoder_lengths = torch.tensor([decoder_output.shape[1]], dtype=torch.long).expand(decoder_output.shape[0])

if decoder_output.ndim == 2:
hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths)
else:
hypotheses = self._greedy_decode_logprobs_batched(decoder_output, decoder_lengths)
packed_result = pack_hypotheses(hypotheses, decoder_lengths)
packed_result = pack_hypotheses(hypotheses, input_decoder_lengths)
return (packed_result,)

@torch.no_grad()
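
For reference, a self-contained sketch of the length-defaulting behavior added to GreedyBatchedCTCInfer.forward above; the batch, time, and class counts are made-up values for illustration:

```python
import torch

B, T, C = 4, 10, 29                    # assumed batch, time, and class counts
decoder_output = torch.randn(B, T, C)  # [B, T, C] batched log-probs
decoder_lengths = None                 # caller omitted the lengths

if decoder_lengths is None:
    # One full-length value (the size of the time axis) broadcast to B
    # entries; .expand() returns a view rather than copying data.
    decoder_lengths = torch.tensor([decoder_output.shape[1]], dtype=torch.long).expand(
        decoder_output.shape[0]
    )

print(decoder_lengths.tolist())  # [10, 10, 10, 10]
```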
18 changes: 14 additions & 4 deletions tests/collections/asr/decoding/test_ctc_decoding.py
@@ -199,7 +199,10 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignme
@pytest.mark.parametrize('alignments', [False, True])
@pytest.mark.parametrize('timestamps', [False, True])
@pytest.mark.parametrize('preserve_frame_confidence', [False, True])
def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence):
@pytest.mark.parametrize('length_is_none', [False, True])
def test_batched_decoding_logprobs(
self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence, length_is_none
):
cfg = CTCBPEDecodingConfig(
strategy='greedy',
preserve_alignments=alignments,
@@ -219,7 +222,10 @@ def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps,
# that we always handle at least a few blanks.
input_signal[:, 0, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
input_signal[:, 1, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
length = torch.randint(low=1, high=T, size=[B])
if length_is_none:
length = None
else:
length = torch.randint(low=1, high=T, size=[B])

with torch.inference_mode():
hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
@@ -242,7 +248,8 @@ def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps,

@pytest.mark.unit
@pytest.mark.parametrize('timestamps', [False, True])
def test_batched_decoding_labels(self, tmp_tokenizer, timestamps):
@pytest.mark.parametrize('length_is_none', [False, True])
def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none):
cfg = CTCBPEDecodingConfig(strategy='greedy', compute_timestamps=timestamps)
unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
cfg.strategy = 'greedy_batch'
@@ -256,7 +263,10 @@ def test_batched_decoding_labels(self, tmp_tokenizer, timestamps):
# at least a few blanks.
input_labels[:, 0] = unbatched_decoding.tokenizer.tokenizer.vocab_size
input_labels[:, 1] = unbatched_decoding.tokenizer.tokenizer.vocab_size
length = torch.randint(low=1, high=T, size=[B])
if length_is_none:
length = None
else:
length = torch.randint(low=1, high=T, size=[B])

with torch.inference_mode():
hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
