Beam search algorithm implementation for TDT models #10903

Merged: 126 commits, Nov 13, 2024
Commits
0f0958b
initial commit
Jul 24, 2024
adb08dd
add: default beam search implementation
Jul 24, 2024
30d0599
fix: changed to removing duplicate hypothesis in separate function
Jul 25, 2024
796bbc7
fix: changed to cartesian product in choosing best hyp
Jul 30, 2024
47bbffa
fix: minor fixes in comments
Jul 30, 2024
1b64767
add: maes decoding strategy
Jul 30, 2024
7551e32
add: durations filtering in maes, lm fusion in progress
Jul 30, 2024
417eb2d
fix: refactored, added comments, command line args, finalized
Jul 31, 2024
4b83af3
fix: removed prints
Jul 31, 2024
51ab336
add: docs
Jul 31, 2024
f88dcca
Merge branch 'main' into beam_search
lilithgrigoryan Jul 31, 2024
1115529
Apply isort and black reformatting
lilithgrigoryan Jul 31, 2024
d5cad08
fix: minor fix
Aug 12, 2024
2d8b455
fix: rm beam_size=1 exception, rm duplicates check, fix error handling
Aug 14, 2024
90e452c
fix: error handling
Aug 14, 2024
3c1dd89
merge
Aug 14, 2024
98f4d53
Apply isort and black reformatting
lilithgrigoryan Aug 14, 2024
20dfd07
fix: removed evaluations file
Aug 14, 2024
1495f0d
Merge branch 'beam_search' of https://github.com/lilithgrigoryan/NeMo…
Aug 14, 2024
0d472cf
rn: blank scoring
Aug 15, 2024
71f0607
clean up
Aug 15, 2024
2f1f495
rm: blank scoring and duration beam size
Aug 19, 2024
e748171
Apply isort and black reformatting
lilithgrigoryan Aug 19, 2024
b75ff04
fix: removed durations_beam_size from default beam search
Aug 19, 2024
5cfbd2d
merge
Aug 19, 2024
6b6fa1f
add: logaddexp
Aug 23, 2024
3e95406
rm: prefix search
Aug 23, 2024
9a386e2
rn: nested loop over extensions
Aug 23, 2024
9d8aeeb
fix: bug with caching
Aug 23, 2024
08aecb7
rm: topk on durations
Aug 23, 2024
d1ce7e9
add: restored prefix search
Aug 23, 2024
ad96664
Apply isort and black reformatting
lilithgrigoryan Aug 23, 2024
19f0fbc
clean up
Aug 23, 2024
707a327
Merge branch 'beam_search' of https://github.com/lilithgrigoryan/NeMo…
Aug 23, 2024
bbfe224
fix: fixed comments
Aug 23, 2024
6a288c8
refactored duplicate merging
Aug 25, 2024
73b55b4
changes batch scoring
Aug 28, 2024
a752bbe
refactored rnnt batch scoring
Aug 29, 2024
7afccf3
alsd first working
Aug 30, 2024
2b1aa91
refactored
Aug 30, 2024
3592451
clean up
Aug 30, 2024
5d48a3f
remove stacking operations
Sep 18, 2024
c11d09f
fixes im base class
Sep 18, 2024
1540047
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
Sep 18, 2024
271e2a6
clean up
Sep 18, 2024
51074bf
Apply isort and black reformatting
lilithgrigoryan Sep 18, 2024
05380d0
remove potentially uninitialized local variable
Sep 18, 2024
98ebb9f
Merge branch 'lgrigoryan/rm-redundant-calculations' of https://github…
Sep 18, 2024
08955c7
merge
Oct 7, 2024
4a06795
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
Oct 7, 2024
61044bc
default beam search minor fixes
Oct 16, 2024
f1504b2
add test, fix maes timesteps
Oct 16, 2024
3d4d913
rm file
Oct 16, 2024
64e8cfd
cleanuo
Oct 16, 2024
9a6c940
rm file
Oct 16, 2024
4632a56
clean up
Oct 16, 2024
a34c0b4
Apply isort and black reformatting
lilithgrigoryan Oct 16, 2024
7ce4b2a
clean up
Oct 16, 2024
a8ee1f4
merge
Oct 16, 2024
b6dd217
fix comments
Oct 16, 2024
0de50f1
merge main
Nov 5, 2024
406e1ab
add ngram lm test
Nov 5, 2024
0076b9c
Apply isort and black reformatting
lilithgrigoryan Nov 5, 2024
49c208e
fix maes_num_steps=1
Nov 5, 2024
49ca284
fix kenlm model path
Nov 5, 2024
77f5b88
fix kenlm model full path
Nov 5, 2024
7c063a9
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
Nov 5, 2024
0726de1
merge
Nov 5, 2024
8684cff
Apply isort and black reformatting
lilithgrigoryan Nov 5, 2024
e12e124
made requested changes
Nov 5, 2024
7a61367
Merge branch 'lgrigoryan/tdt_beam_search' of https://github.com/NVIDI…
Nov 5, 2024
629b053
merge after isort
Nov 5, 2024
3f8e1c9
add prints to test
Nov 5, 2024
560d505
Apply isort and black reformatting
lilithgrigoryan Nov 5, 2024
19dfab1
add Kenlm to asr requirements
Nov 6, 2024
64bb51b
Merge branch 'lgrigoryan/tdt_beam_search' of https://github.com/NVIDI…
Nov 6, 2024
2c62bdf
remove prints in tests
Nov 6, 2024
0f40f28
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2024
c7da895
add kenlm to test requirements
Nov 6, 2024
35317d9
Merge branch 'lgrigoryan/tdt_beam_search' of https://github.com/NVIDI…
Nov 6, 2024
61ab404
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2024
7f031e2
rm kenlm from link, add package-name
Nov 6, 2024
39018d2
merge
Nov 6, 2024
269dc87
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2024
f23bdce
rm second kenlm installation
Nov 6, 2024
ea47d23
rm kenlm from dependencies make test optional
Nov 6, 2024
52a89f2
Apply isort and black reformatting
lilithgrigoryan Nov 6, 2024
7f9c1dd
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
Nov 6, 2024
cb87610
Merge branch 'lgrigoryan/tdt_beam_search' of https://github.com/NVIDI…
Nov 6, 2024
9038849
fix in test
Nov 6, 2024
b7111f4
fix in test
Nov 6, 2024
c900892
Apply isort and black reformatting
lilithgrigoryan Nov 6, 2024
3cb46ef
fix comments
Nov 6, 2024
ccf7933
Apply isort and black reformatting
lilithgrigoryan Nov 6, 2024
20c3d4c
add comments
Nov 6, 2024
0d6a027
add comments
Nov 6, 2024
97584c1
splitted docstrings
Nov 6, 2024
80bf636
Merge branch 'lgrigoryan/tdt_beam_search' of https://github.com/NVIDI…
Nov 6, 2024
fe2d5b9
Apply isort and black reformatting
lilithgrigoryan Nov 6, 2024
2fbf375
add comments
Nov 6, 2024
3bd4f99
splitted docstrings
Nov 6, 2024
d592af9
Merge branch 'lgrigoryan/tdt_beam_search' of https://github.com/NVIDI…
Nov 6, 2024
b1df71e
Apply isort and black reformatting
lilithgrigoryan Nov 6, 2024
24cacc4
add comments
Nov 6, 2024
bb42ab5
Merge branch 'lgrigoryan/tdt_beam_search' of https://github.com/NVIDI…
Nov 6, 2024
70bb42b
Apply isort and black reformatting
lilithgrigoryan Nov 6, 2024
a1d348c
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
Nov 6, 2024
fb720df
Merge branch 'main' into lgrigoryan/tdt_beam_search
lilithgrigoryan Nov 6, 2024
028ea39
fixes to python3 type annotations
Nov 7, 2024
712ec52
Merge branch 'lgrigoryan/tdt_beam_search' of https://github.com/NVIDI…
Nov 7, 2024
6379324
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
Nov 7, 2024
3b58359
Apply isort and black reformatting
lilithgrigoryan Nov 7, 2024
b9d4459
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
Nov 12, 2024
008d507
Merge branch 'lgrigoryan/tdt_beam_search' of https://github.com/NVIDI…
Nov 12, 2024
cf402c3
merging
Nov 12, 2024
6fa3a7b
merging
Nov 12, 2024
0d544b5
fix in return type
Nov 12, 2024
f1d8932
Apply isort and black reformatting
lilithgrigoryan Nov 12, 2024
74542a8
fix test
Nov 12, 2024
0070039
Merge branch 'lgrigoryan/tdt_beam_search' of https://github.com/NVIDI…
Nov 12, 2024
677b88d
Apply isort and black reformatting
lilithgrigoryan Nov 12, 2024
0669d3d
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
Nov 12, 2024
2169fd0
Merge branch 'lgrigoryan/tdt_beam_search' of https://github.com/NVIDI…
Nov 12, 2024
28dce00
rm time_idx
Nov 13, 2024
ef99181
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
Nov 13, 2024
c61f01b
fix comments to python3 style
Nov 13, 2024
15 changes: 15 additions & 0 deletions docs/source/asr/api.rst
@@ -276,6 +276,21 @@ RNNT Decoding
:show-inheritance:
:members:

TDT Decoding
~~~~~~~~~~~~~

.. autoclass:: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyTDTInfer
Collaborator comment: Thanks for adding these classes!

:show-inheritance:
:members:

.. autoclass:: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyBatchedTDTInfer
:show-inheritance:
:members:

.. autoclass:: nemo.collections.asr.parts.submodules.tdt_beam_decoding.BeamTDTInfer
:show-inheritance:
:members:

Hypotheses
~~~~~~~~~~

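The classes documented above are the greedy and beam TDT decoders covered by this PR. As a rough sketch (not taken from this PR), beam decoding for a TDT model is typically enabled through the model's decoding config rather than by instantiating BeamTDTInfer directly; the checkpoint name and the exact config field names below are assumptions:

import nemo.collections.asr as nemo_asr
from omegaconf import open_dict

# Load a TDT model; the checkpoint name is an assumption used only for illustration.
asr_model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt-1.1b")

# Copy the model's current decoding config and switch it to beam search.
decoding_cfg = asr_model.cfg.decoding
with open_dict(decoding_cfg):
    decoding_cfg.strategy = "beam"      # assumed field: routes decoding to the beam infer classes
    decoding_cfg.beam.beam_size = 4     # assumed field: beam width
asr_model.change_decoding_strategy(decoding_cfg)

transcripts = asr_model.transcribe(["audio_sample.wav"])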
56 changes: 46 additions & 10 deletions nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py
@@ -55,6 +55,20 @@


def pack_hypotheses(hypotheses: List[Hypothesis]) -> List[Hypothesis]:
"""
Packs a list of hypotheses into a tensor and prepares decoder states.

This function takes a list of token sequences (hypotheses) and converts
it into a tensor format. If any decoder states are on the GPU, they
are moved to the CPU. Additionally, the function removes any timesteps
with a value of -1 from the sequences.

Args:
hypotheses (list): A list of token sequences representing hypotheses.

Returns:
list: A list of packed hypotheses in tensor format.
"""
for idx, hyp in enumerate(hypotheses): # type: rnnt_utils.Hypothesis
hyp.y_sequence = torch.tensor(hyp.y_sequence, dtype=torch.long)
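# As an aside, a minimal, hypothetical sketch of the behaviour the docstring above
# describes (not the PR's implementation; `dec_state` and `timestep` are assumed
# attribute names, not read from this diff):
import torch

def pack_hypotheses_sketch(hypotheses):
    for hyp in hypotheses:
        # 1) convert the token sequence to a LongTensor
        hyp.y_sequence = torch.tensor(hyp.y_sequence, dtype=torch.long)
        # 2) move any GPU-resident decoder states back to the CPU
        if torch.is_tensor(hyp.dec_state):
            hyp.dec_state = hyp.dec_state.to('cpu')
        # 3) drop placeholder timesteps marked with -1
        hyp.timestep = [t for t in hyp.timestep if t != -1]
    return hypotheses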

@@ -69,6 +83,18 @@ def pack_hypotheses(hypotheses: List[Hypothesis]) -> List[Hypothesis]:


def _states_to_device(dec_state, device='cpu'):
"""
Transfers decoder states to the specified device.

This function moves the provided decoder states to the specified device (e.g., 'cpu' or 'cuda').

Args:
dec_state (Tensor): The decoder states to be transferred.
device (str): The target device to which the decoder states should be moved. Defaults to 'cpu'.

Returns:
Tensor: The decoder states on the specified device.
"""
if torch.is_tensor(dec_state):
dec_state = dec_state.to(device)

@@ -106,15 +132,17 @@ class BeamRNNTInfer(Typing):
however the time required for the search also grows steadily.

`tsd` - time synchronous decoding. Please refer to the paper:
[Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040)
[Alignment-Length Synchronous Decoding for RNN Transducer]
(https://ieeexplore.ieee.org/document/9053040)
for details on the algorithm implemented.

Time synchronous decoding (TSD) execution time grows by the factor T * max_symmetric_expansions.
For longer sequences, T is greater, and can therefore take a long time for beams to obtain
good results. This also requires greater memory to execute.

`alsd` - alignment-length synchronous decoding. Please refer to the paper:
[Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040)
[Alignment-Length Synchronous Decoding for RNN Transducer]
(https://ieeexplore.ieee.org/document/9053040)
for details on the algorithm implemented.

Alignment-length synchronous decoding (ALSD) execution time is faster than TSD, with growth
@@ -127,7 +155,8 @@
For a given decoding accuracy, it is possible to attain faster decoding via ALSD than TSD.

`maes` - modified adaptive expansion search. Please refer to the paper:
[Accelerating RNN Transducer Inference via Adaptive Expansion Search](https://ieeexplore.ieee.org/document/9250505)
[Accelerating RNN Transducer Inference via Adaptive Expansion Search]
(https://ieeexplore.ieee.org/document/9250505)

Modified Adaptive Expansion Search (mAES) execution time is adaptive w.r.t. the
number of expansions (for tokens) required per timestep. The number of expansions can usually
@@ -169,10 +198,10 @@ class BeamRNNTInfer(Typing):
and affects the speed of inference since large values will perform large beam search in the next step.

maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions.
The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v])
where v is all vocabulary indices in the Vocab set and max_log_prob is the "most" likely token to be
predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for
expansion apart from the "most likely" candidate.
The default (2.3) is selected from the paper. It performs a comparison
(max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and max_log_prob
is the "most" likely token to be predicted. Gamma therefore provides a margin of additional tokens which
can be potential candidates for expansion apart from the "most likely" candidate.
Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed
but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value,
thereby reducing speed but potentially improving accuracy). This is a hyper parameter to be experimentally
@@ -182,7 +211,7 @@

preserve_alignments: Bool flag which preserves the history of alignments generated during
beam decoding (sample). When set to true, the Hypothesis will contain
the non-null value for `alignments` in it. Here, `alignments` is a List of List of Tensor (of length V + 1).
the non-null value for `alignments` in it. Here, `alignments` is a List of List of Tensor (of length V + 1)

The length of the list corresponds to the Acoustic Length (T).
Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary.
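To make the prune-by-value step governed by maes_expansion_gamma concrete, here is a small illustrative sketch (not the PR's implementation) of the comparison max_log_prob - gamma <= log_prob[v] described above:

import torch

def prune_by_value(log_probs: torch.Tensor, gamma: float = 2.3) -> torch.Tensor:
    # log_probs: (V,) per-token log-probabilities at one expansion step
    max_log_prob = log_probs.max()
    keep_mask = log_probs >= (max_log_prob - gamma)        # max_log_prob - gamma <= log_prob[v]
    return keep_mask.nonzero(as_tuple=False).squeeze(-1)   # indices of surviving candidate tokens

# A larger gamma keeps more expansions (slower, potentially more accurate);
# a smaller gamma prunes harder (faster, potentially less accurate).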
@@ -1456,8 +1485,11 @@ def compute_ngram_score(self, current_lm_state: "kenlm.State", label: int) -> Tu
return lm_score, next_state

def set_decoding_type(self, decoding_type: str):

# Please check train_kenlm.py in scripts/asr_language_modeling/ to find out why we need
"""
Sets decoding type. Please check train_kenlm.py in scripts/asr_language_modeling/ to find out why we need
Args:
decoding_type: decoding type
"""
# TOKEN_OFFSET for BPE-based models
if decoding_type == 'subword':
from nemo.collections.asr.parts.submodules.ctc_beam_decoding import DEFAULT_TOKEN_OFFSET
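# For context, the train_kenlm.py reference above concerns how subword token ids are
# mapped into the KenLM vocabulary. A hypothetical illustration of that offset mapping
# follows; the offset value is an assumption, not read from this diff.
DEFAULT_TOKEN_OFFSET = 100  # assumed value; the real constant is defined in ctc_beam_decoding.py

def token_to_lm_char(token_id: int) -> str:
    # KenLM models trained on BPE outputs encode each subword id as a unicode
    # character shifted by a fixed offset, so the decoder applies the same offset
    # when querying the n-gram LM.
    return chr(token_id + DEFAULT_TOKEN_OFFSET)

# e.g. subword id 5 would be looked up in the LM as chr(105), i.e. 'i'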
@@ -1467,6 +1499,10 @@ def set_decoding_type(self, decoding_type: str):

@dataclass
class BeamRNNTInferConfig:
"""
Beam RNNT Inference config.
"""

beam_size: int
search_type: str = 'default'
score_norm: bool = True
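As a usage note, the dataclass above can be instantiated directly. A minimal sketch using only the fields visible in this diff (the class exposes further options not shown here):

from nemo.collections.asr.parts.submodules.rnnt_beam_decoding import BeamRNNTInferConfig

beam_cfg = BeamRNNTInferConfig(beam_size=4, search_type="maes", score_norm=True)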