Conformer lm #54

Status: Open. Wants to merge 29 commits into master.
Changes from 18 commits (29 commits in total).

Commits
421a410
Get dataset.py working..
danpovey Aug 21, 2021
cbe5ee1
Copy some files, will edit..
danpovey Aug 21, 2021
076a70b
Initial conformer refactoring, not nearly done
danpovey Aug 22, 2021
ea43b49
Remove BatchNorm, use LayerNorm
danpovey Aug 22, 2021
24d3a98
Merge remote-tracking branch 'upstream/master'
danpovey Aug 22, 2021
03ff4aa
Some progress on refactoring conformer code, it's in transformer.py o…
danpovey Aug 23, 2021
e0b04ba
Progress in testing
danpovey Aug 23, 2021
2fbe3b7
Add more testing; fix issue about channel dim of LayerNorm.
danpovey Aug 23, 2021
556fae5
Add testing for MaskedLmConformerEncoder
danpovey Aug 23, 2021
7856ab8
Test, and fix, TransformerDecoderLayerRelPos
danpovey Aug 23, 2021
5fecd24
Test, and fix, TransformerDecoderRelPos
danpovey Aug 23, 2021
26b5b5b
Get tests to work for MaskedLmConformer
danpovey Aug 23, 2021
13200d7
Merge remote-tracking branch 'upstream/master'
danpovey Aug 23, 2021
894be06
Update prepare.sh to create LM training data; add missed scripts loca…
danpovey Aug 23, 2021
c3a8727
Add train.py
danpovey Aug 23, 2021
7711fba
Fix bugs; first version that is running successfully.
danpovey Aug 23, 2021
9576d65
Various bug fixes
danpovey Aug 23, 2021
e6eefeb
Changes to dataset to prevent OOM on batches with short sentences
danpovey Aug 24, 2021
0d97e68
Version I am running...
danpovey Aug 24, 2021
a7b6110
Use collate_fn as class. Harmless but not necessary without multiple…
danpovey Aug 25, 2021
d045831
Get dataset to work for empty input sentences; test it
danpovey Aug 25, 2021
ccf7bde
Add Foam optimizer; I used this from epoch 3.
danpovey Aug 28, 2021
573e058
Run in exp_2, with foam from start, knee_factor=5.0, initial_lrate=2e…
danpovey Aug 30, 2021
d313c27
Change configuration again.. not great performance.
danpovey Sep 7, 2021
56a88ba
Move to Gloam optimizer, exponential lrate
danpovey Sep 8, 2021
d0e5b9b
Change to exp_5, 1/sqrt(t) component.
danpovey Sep 9, 2021
3ce1de3
Updates for new k2 version; change LR decay from 0.85 to 0.9
danpovey Sep 13, 2021
0cfa8c8
Merge remote-tracking branch 'upstream/master' into conformer_lm
danpovey Sep 24, 2021
8e650a5
Update egs/librispeech/ASR/conformer_lm/conformer.py
danpovey Sep 27, 2021
1,484 changes: 1,484 additions & 0 deletions egs/librispeech/ASR/conformer_lm/conformer.py

Large diffs are not rendered by default.
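
Since the conformer.py diff is not rendered, the minimal sketch below (distilled from test_conformer.py further down in this PR) shows how MaskedLmConformer is driven end to end; the call signatures are copied from that test, and the comments are my reading of them rather than documentation from conformer.py itself:

import dataset
from conformer import MaskedLmConformer

# Collate three toy sentences into masked input, source/target symbols,
# a key-padding mask and per-position target weights.
(masked_src_symbols, src_symbols, tgt_symbols,
 src_key_padding_mask, tgt_weights) = dataset.collate_fn(
    sentences=[list(range(10, 20)), list(range(30, 45)), list(range(50, 68))],
    bos_sym=1, eos_sym=2, blank_sym=0)

num_classes = 87   # vocabulary size used in the test
d_model = 256
model = MaskedLmConformer(num_classes, d_model)

# Encoder pass: returns the encoder output ("memory") and a relative
# positional embedding that the decoder reuses.
memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)

# Per-position negative log-likelihood from the decoder; the training loss
# is the weighted sum over positions.
nll = model.decoder_nll(memory, pos_emb, src_symbols, tgt_symbols,
                        src_key_padding_mask)
loss = (nll * tgt_weights).sum()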

812 changes: 812 additions & 0 deletions egs/librispeech/ASR/conformer_lm/dataset.py

Large diffs are not rendered by default.
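
The dataset.py diff is likewise not rendered, so the sketch below (adapted from test_dataset.py at the bottom of this PR) shows the pieces it exposes; the explanatory comments are my interpretation of the tests, not documentation from the file:

import torch
import dataset

# lm_data.pt is created by the updated prepare.sh (see the commit list above).
train, test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')

# Batches are sized by total symbol count; world_size/rank allow distributed
# sampling (the test uses world_size=2, rank=0 to simulate one rank of a
# two-process job).
sampler = dataset.LmBatchSampler(test, symbols_per_batch=1000, world_size=1, rank=0)

# collate_fn turns a list of sentences (lists of symbol ids) into the tuple
# (masked_src_symbols, src_symbols, tgt_symbols, src_key_padding_mask,
# tgt_weights) consumed by MaskedLmConformer above.
loader = torch.utils.data.DataLoader(
    test, batch_sampler=sampler,
    collate_fn=lambda x: dataset.collate_fn(x, bos_sym=1, eos_sym=2, blank_sym=0))
batch = next(iter(loader))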

959 changes: 959 additions & 0 deletions egs/librispeech/ASR/conformer_lm/madam.py

Large diffs are not rendered by default.

156 changes: 156 additions & 0 deletions egs/librispeech/ASR/conformer_lm/test_conformer.py
@@ -0,0 +1,156 @@
#!/usr/bin/env python3
# run with:
# python3 -m pytest test_conformer.py

import torch
import dataset # from .
from conformer import (
RelPosTransformerDecoder,
RelPosTransformerDecoderLayer,
MaskedLmConformer,
MaskedLmConformerEncoder,
MaskedLmConformerEncoderLayer,
RelPositionMultiheadAttention,
RelPositionalEncoding,
generate_square_subsequent_mask,
)

from torch.nn.utils.rnn import pad_sequence


def test_rel_position_multihead_attention():
# Also tests RelPositionalEncoding
embed_dim = 256
num_heads = 4
T = 25
N = 4
C = 256
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)

x = torch.randn(N, T, C)
#pos_emb = torch.randn(1, 2*T-1, C)
x, pos_emb = pos_emb_module(x)
x = x.transpose(0, 1) # (T, N, C)
attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_emb)


def test_masked_lm_conformer_encoder_layer():
# Also tests RelPositionalEncoding
embed_dim = 256
num_heads = 4
T = 25
N = 4
C = 256
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)


x = torch.randn(N, T, C)
x, pos_emb = pos_emb_module(x)
x = x.transpose(0, 1) # (T, N, C)
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
y = encoder_layer(x, pos_emb, key_padding_mask=key_padding_mask)


def test_masked_lm_conformer_encoder():
# Also tests RelPositionalEncoding
embed_dim = 256
num_heads = 4
T = 25
N = 4
C = 256
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads)
norm = torch.nn.LayerNorm(embed_dim)
encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=4,
norm=norm)


x = torch.randn(N, T, C)
x, pos_emb = pos_emb_module(x)
x = x.transpose(0, 1) # (T, N, C)
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
y = encoder(x, pos_emb, key_padding_mask=key_padding_mask)


def test_transformer_decoder_layer_rel_pos():
embed_dim = 256
num_heads = 4
T = 25
N = 4
C = 256
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads)


x = torch.randn(N, T, C)
x, pos_emb = pos_emb_module(x)
x = x.transpose(0, 1) # (T, N, C)
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
attn_mask = generate_square_subsequent_mask(T)
memory = torch.randn(T, N, C)
y = decoder_layer(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)



def test_transformer_decoder_rel_pos():
embed_dim = 256
num_heads = 4
T = 25
N = 4
C = 256
pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0)
decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads)
decoder_norm = torch.nn.LayerNorm(embed_dim)
decoder = RelPosTransformerDecoder(decoder_layer, num_layers=6, norm=decoder_norm)

x = torch.randn(N, T, C)
x, pos_emb = pos_emb_module(x)
x = x.transpose(0, 1) # (T, N, C)
key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T)
attn_mask = generate_square_subsequent_mask(T)
memory = torch.randn(T, N, C)
y = decoder(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask)


def test_masked_lm_conformer():

num_classes = 87
d_model = 256

    model = MaskedLmConformer(num_classes, d_model)


N = 31


    (masked_src_symbols, src_symbols, tgt_symbols,
     src_key_padding_mask, tgt_weights) = dataset.collate_fn(
         sentences=[list(range(10, 20)), list(range(30, 45)), list(range(50, 68))],
         bos_sym=1, eos_sym=2, blank_sym=0)

# test forward() of MaskedLmConformer
memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
nll = model.decoder_nll(memory, pos_emb, src_symbols, tgt_symbols,
src_key_padding_mask)
print("nll = ", nll)
loss = (nll * tgt_weights).sum()
print("loss = ", loss)



def test_generate_square_subsequent_mask():
s = 5
mask = generate_square_subsequent_mask(s, torch.device('cpu'))
inf = float("inf")
expected_mask = torch.tensor(
[
[0.0, -inf, -inf, -inf, -inf],
[0.0, 0.0, -inf, -inf, -inf],
[0.0, 0.0, 0.0, -inf, -inf],
[0.0, 0.0, 0.0, 0.0, -inf],
[0.0, 0.0, 0.0, 0.0, 0.0],
]
)
assert torch.all(torch.eq(mask, expected_mask))
13 changes: 13 additions & 0 deletions egs/librispeech/ASR/conformer_lm/test_dataset.py
@@ -0,0 +1,13 @@
# Standalone sanity check for dataset.py: loads the prepared LM training data,
# draws a batch of sentence indices from LmBatchSampler and a collated batch
# from a DataLoader, and prints them.
# Requires ../data/lm_training_5000/lm_data.pt to exist.
import dataset
import torch


train, test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt')

# world_size=2, rank=0 simulates rank 0 of a two-process distributed job.
sampler = dataset.LmBatchSampler(test, symbols_per_batch=1000, world_size=2, rank=0)
a = iter(sampler)
print(str(next(a)))

collate_fn = (lambda x: dataset.collate_fn(x, bos_sym=1, eos_sym=1, blank_sym=0, debug=True))
train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler, collate_fn=collate_fn)
x = iter(train_dl)
print(str(next(x)))