From 421a41027a42aa3b9b490787bdaba0eea1568624 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 21 Aug 2021 18:23:46 +0800 Subject: [PATCH 01/26] Get dataset.py working.. --- egs/librispeech/ASR/conformer_lm/dataset.py | 783 ++++++++++++++++++++ egs/librispeech/ASR/local/download_lm.py | 1 + egs/librispeech/ASR/prepare.sh | 5 + 3 files changed, 789 insertions(+) create mode 100644 egs/librispeech/ASR/conformer_lm/dataset.py diff --git a/egs/librispeech/ASR/conformer_lm/dataset.py b/egs/librispeech/ASR/conformer_lm/dataset.py new file mode 100644 index 0000000000..75f603d9d4 --- /dev/null +++ b/egs/librispeech/ASR/conformer_lm/dataset.py @@ -0,0 +1,783 @@ +import torch +import torch.distributed as dist +import k2 +import _k2 +import sentencepiece as spm +from typing import Optional, List, Tuple + + + +class LmDataset(torch.utils.data.Dataset): + """ + Torch dataset for language modeling data. This is a map-style dataset. + The indices are integers. + """ + def __init__(self, sentences: k2.RaggedInt, + words: k2.RaggedInt): + super(LmDataset, self).__init__() + self.sentences = sentences + self.words = words + + + def __len__(self): + # Total size on axis 0, == num sentences + return self.sentences.tot_size(0) + + def __getitem__(self, i: int): + """ + Return the i'th sentence, as a list of ints (representing BPE pieces, without + bos or eos symbols). + """ + # It would be nicer if we could just return self.sentences[i].tolist(), but + # for now that operator on k2.RaggedInt is not implemented. + row_splits = self.sentences.row_splits(1) + (begin, end) = row_splits[i:i+2].tolist() + sentence = self.sentences.values()[begin:end] + return k2.index(self.words, sentence).values().tolist() + + +def load_train_test_lm_dataset(archive_fn: str, + test_proportion: float = 0.025) -> Tuple[LmDataset, LmDataset]: + """ + returns (train_lm_dataset, test_lm_dataset) + """ + + d = torch.load(archive_fn) + words = d['words'] # a k2.RaggedInt with 2 axes, maps from word-ids to sequences of BPE pieces + sentences = d['data'] # a k2.RaggedInt + + with torch.random.fork_rng(devices=[]): + g = torch.manual_seed(0) + num_sentences = sentences.tot_size(0) + # probably the generator (g) argument to torch.randperm below is not necessary. + sentence_perm = torch.randperm(num_sentences, generator=g, dtype=torch.int32) + sentences = k2.index(sentences, sentence_perm) + + num_test_sentences = int(num_sentences * test_proportion) + + axis=0 + train_sents = _k2.ragged_int_arange(sentences, axis, + num_test_sentences, num_sentences) + test_sents = _k2.ragged_int_arange(sentences, axis, 0, num_test_sentences) + + return LmDataset(train_sents, words), LmDataset(test_sents, words) + + +def mask_and_pad(sentence: List[int], + seq_len: int, + bos_sym: int, + eos_sym: int, + blank_sym: int, + mask_proportion: float, + padding_proportion: float, + inv_mask_length: float, + unmasked_weight: float) -> Tuple[List[int], List[int], List[int], List[float]]: + """ + This function contains part of the logic of collate_fn, broken out. It is responsible + for inserting masking and padding into the sequence `sentence`. Most of the arguments + are documented for `collate_fn` below. + Other args: + sentence: The original sentence to be masked and padded. + seq_len: The desired length of the lists to be returned + bos_sym, eos_sym, blank_sym, mask_proportion, + padding_proportion, inv_mask_length, unmasked_weight: see their documentation + as args to `collate_fn` below. 
+
+
+    Return: a tuple (src, src_masked, tgt, weight, randomizable, attn_mask),
+       all lists of length `seq_len`, where:
+        `src` is: [bos] + [the sentence after inserting blanks in place of padding
+                 after regions to be masked] + [eos] + [blank padding to seq_len].
+        `src_masked` is as `src` but the masked regions have their values replaced with blank,
+                 i.e. they are actually masked.
+        `tgt` is: [the original sentence, without masking] + [eos] + [blank] + [blank padding to seq_len]
+        `weight` is the weight at the nnet output, which is: `unmasked_weight` for un-masked
+                 positions, 1.0 for masked and padded positions, and 0.0 for positions that
+                 correspond to blank-padding after the final [eos].
+        `randomizable` is a list of bool that is True for positions where the symbol
+                 in `src_masked` is not bos, eos or blank.
+        `attn_mask` is a list of bool that is False for positions in `src` and `src_masked` that
+                 are between the initial [bos] and final [eos] inclusive, and True for
+                 positions after the final [eos].
+    """
+    sent_len = len(sentence)
+    assert sent_len + 3 <= seq_len
+
+    for w in sentence:
+        assert w not in [bos_sym, eos_sym, blank_sym]
+
+    num_mask = int(torch.binomial(count=torch.tensor([sent_len * 1.0]),
+                                  prob=torch.tensor([mask_proportion])).item())
+    num_pad = int(torch.poisson(torch.tensor([sent_len * padding_proportion])).item())
+    # Ensure the total length after bos, padding of masked sequences, and eos, is
+    # no greater than seq_len.
+    num_pad -= max(0, sent_len + 2 + num_pad - seq_len)
+
+    if num_mask + num_pad == 0:
+        num_mask += 1
+
+    # num_split_points is the number of times we split the (masked+padded)
+    # region, so the total number of (masking+padding) subsequences will be
+    # num_split_points + 1.  If num_mask positions are masked, then the
+    # remaining number of words is `sent_len - num_mask`, and any two
+    # masked regions must have at least one non-masked word between them,
+    # so num_split_points == number of masked regions - 1 must be
+    # no greater than `sent_len - num_mask`.  The formula
+    # mask_proportion * inv_mask_length / (1.0 - mask_proportion)
+    # is what's required (I think) so that inv_mask_length is the expected
+    # length of masked regions.
+    num_split_points = int(torch.binomial(count=torch.tensor([float(sent_len - num_mask)]),
+                                          prob=torch.tensor([mask_proportion * inv_mask_length / (1.0 - mask_proportion)])).item())
+    assert num_split_points <= sent_len - num_mask
+    assert isinstance(num_split_points, int)
+
+    def split_into_subseqs(length: int, num_subseqs: int) -> List[int]:
+        """Splits a sequence of `length` items into `num_subseqs` possibly-empty
+        subsequences.  The length distributions are geometric, not Poisson, i.e.
+        we choose the split locations with uniform probability rather than
+        randomly assigning each word to one of the subsequences.  This gives us a
+        greater spread of shorter and longer subsequences.
+        Requires num_subseqs > 0.
+        """
+        boundaries = [0] + sorted(torch.randint(low=0, high=length + 1, size=(num_subseqs - 1,)).tolist()) + [length]
+        return [ boundaries[i + 1] - boundaries[i] for i in range(num_subseqs) ]
+
+    mask_lengths = split_into_subseqs(num_mask, num_split_points + 1)
+    pad_lengths = split_into_subseqs(num_pad, num_split_points + 1)
+    # mask_pad_lengths contains only the (mask, pad) length pairs for which mask + pad > 0.
+    # From this point we only refer to mask_pad_lengths.
+ mask_pad_lengths = [ (mask, pad) for (mask, pad) in zip(mask_lengths, pad_lengths) if mask+pad > 0 ] + num_subseqs = len(mask_pad_lengths) + assert num_subseqs > 0 + + # Now figure out how to distribute these subsequences throughout the actual + # sentence. The subsequences, if there are more than one, must not touch, + # i.e. there must be an actual word in between each subsequence, where the + # number of such "mandatory" words equals num_subseqs - 1. We also have to + # subtract `num_mask` words, since obviously the masked words cannot separate + # the masked regions. + reduced_len = sent_len - num_mask - (num_subseqs - 1) + assert reduced_len >= 0 + # unmasked_lengths will be the lengths of the un-masked regions between the masked + # regions. + unmasked_lengths = split_into_subseqs(reduced_len, num_subseqs + 1) + for i in range(1, num_subseqs): + # Unmasked regions between masked regions must have length at least 1, + # we add 1 to unmasked regions that are not initial/final. + unmasked_lengths[i] = unmasked_lengths[i] + 1 + assert sum(unmasked_lengths) + sum(mask_lengths) == sent_len + + + # src_positions will be: for each position in the masked+padded sentence, + # the corresponding position in the source sentence `sentence`; or -1 + # if this was padding. + src_positions = [] + # `masked` will be: for each position in the masked+padded sentence, True if + # it was masked and False otherwise. (Note: it is False for padding + # locations, although this will not matter in the end). + masked = [] + + cur_pos = 0 # current position in source sentence + for i in range(num_subseqs + 1): + for j in range(unmasked_lengths[i]): + src_positions.append(cur_pos) + masked.append(False) + cur_pos += 1 + if i < num_subseqs: + (mask_len, pad_len) = mask_pad_lengths[i] + for j in range(mask_len): + src_positions.append(cur_pos) + masked.append(True) + cur_pos += 1 + for j in range(pad_len): + src_positions.append(-1) + masked.append(False) + assert cur_pos == len(sentence) + + + src = [] + src_masked = [] + tgt = [] + weight = [] + randomizable = [] + + src.append(bos_sym) + src_masked.append(bos_sym) + randomizable.append(False) + for i, src_pos in enumerate(src_positions): + is_masked = masked[i] + if src_pos >= 0: + src_word = sentence[src_pos] + src_masked.append(blank_sym if masked[i] else src_word) + src.append(src_word) + tgt.append(src_word) + weight.append(1.0 if masked[i] else unmasked_weight) + randomizable.append(not masked[i]) + else: + # Padding inside a masked region + src_masked.append(blank_sym) + src.append(blank_sym) + tgt.append(blank_sym) + weight.append(1.0) + randomizable.append(False) + src.append(eos_sym) + src_masked.append(eos_sym) + tgt.append(eos_sym) + weight.append(unmasked_weight) + tgt.append(blank_sym) + weight.append(0.0) + randomizable.append(False) + + attn_mask = ([False] * len(src)) + ([True] * (seq_len - len(src))) + + for i in range(seq_len - len(src)): + src.append(blank_sym) + src_masked.append(blank_sym) + tgt.append(blank_sym) + weight.append(0.0) + randomizable.append(False) + + return (src, src_masked, tgt, weight, randomizable, attn_mask) + + +# dataset.mask_and_pad(list(range(10, 20)), seq_len=16, bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, inv_mask_length=0.33, unmasked_weight=0.444) + +# dataset.collate_fn(sentences=[ list(range(10, 20)), list(range(30, 45))], bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, randomize_proportion=0.05, inv_mask_length=0.33, unmasked_weight=0.444) + 
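+
+# For orientation, a sketch of one possible outcome of the first example call above
+# (hypothetical: the actual masked region, padding and trailing blanks depend on the
+# random draws inside mask_and_pad).  Suppose the draws mask words 13 and 14 and add
+# one blank of padding after them; then with bos=1, eos=2, blank=0, seq_len=16:
+#
+#   src        = [1, 10, 11, 12, 13, 14, 0, 15, 16, 17, 18, 19, 2, 0, 0, 0]
+#   src_masked = [1, 10, 11, 12,  0,  0, 0, 15, 16, 17, 18, 19, 2, 0, 0, 0]
+#   tgt        = [10, 11, 12, 13, 14, 0, 15, 16, 17, 18, 19, 2, 0, 0, 0, 0]
+#   attn_mask  = [False] * 13 + [True] * 3
+#
+# Note that tgt is offset by one relative to src (src has bos prepended, tgt does not),
+# so tgt[j - 1] is the prediction target for input position j; weight is 1.0 at the
+# masked and padded positions of tgt, unmasked_weight elsewhere in the sentence
+# (including the eos), and 0.0 in the blank padding after the final eos.
+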
+def collate_fn(sentences: List[List[int]], + bos_sym: int, + eos_sym: int, + blank_sym: int, + mask_proportion: float = 0.15, + padding_proportion: float = 0.15, + randomize_proportion: float = 0.05, + inv_mask_length: float = 0.25, + unmasked_weight: float = 0.25, + debug: bool = False) -> Tuple[torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor, + torch.Tensor]: + """ + Caution, this is not the collate_fn we give directly to the dataloader, + we give it a lambda: collate_fn=(lambda x: dataset.collate_fn(x, [other args])) + This formats a list-of-lists-of-int into 5 Tensors, explained below. + The key thing is that we mask out subsequences of random length within + these sentences, and force the network to predict the masked-out + subsequences (which have blanks appended to them to prevent the model + from knowing the exact length of the sequences it has to predict). + So it's like BERT but at the level of sequences rather than individual + words. + + Args: + bos_sym: the integer id of the beginning-of-sentence symbol, e.g. 2. + Is allowed be the same as eos_sym (we are not necessarily + saying it will work best that way). + eos_sym: the integer id of the end-of-sentence symbol, e.g. 2. + blank_sym: the integer id of the blank symbol, e.g. 0 or 1. + mask_proportion: The proportion of words in each sentence that + are masked, interpreted as (roughly) the probability of any given + word being masked, although the masked locations will + tend to be in contiguous sequences (they are not independent). + padding_proportion: Like mask_proportion, but determines the + number of extra, blank symbols that are inserted as padding + at the end of masked regions (this ensures that the model + cannot know exactly how many words need to be inserted in + any given masked region. + randomize_proportion: The probability with which we replace + words that were not masked with randomly chosen words. + Like BERT, this is intended to force the model to predict + something reasonable at non-masked positions, and to make + this task harder than simply repeating the input. + inv_mask_length: This number determines how many separate + sub-sequences the (masked + padded) proportion of a sentence is split up + into, interpreted as the inverse of the expected length of + each *masked* region. + unmasked_weight: The weight to be applied to the log-likelihoods of + un-masked positions in sentences (predicting un-masked + positions is not completely trivial if randomize_proportion > 0). + Will be reflected in the returned tgt_weights tensor. + + Returns a tuple (masked_src_symbols, src_symbols, + tgt_symbols, src_attn_mask, + tgt_weights), + all with 2 axes and the same shape: (num_sent, seq_len). + Their dtypes will be, respectively, + (torch.int64, torch.int64, + torch.int64, torch.bool, + torch.float) + masked_src_symbols: The sentences, with bos_symbol prepended and eos_symbol + appended, masked regions (including padding) replaced with blank, + and `randomize_proportion` non-masked symbols replaced with + symbols randomly taken from elsewhere in the sentences of this + minibatch. Then padded to a fixed length with blank. + src_symbols: Like masked_src_symbols, except with the masked symbols replaced + with the original symbols (but the padding that follows each + masked sub-sequence will still be blank) + tgt_symbols: The original sentences, with eos_symbol appended, and then + padded with blank to the same length as masked_symbols and + src_symbols. 
+ src_attn_mask: Masking tensor for masked_src_symbols and src_symbols, to + account for all the sentence lengths not being identical + (makes each sentence's processing independent of seq_len). + Tensor of Bool of shape (num_sent, seq_len), with True + for masked positions (these are the blanks that follow the + eos_symbol in masked_src_symbols), False for un-masked positions. + tgt_weights: Weights that will be applied to the log-probabilities at + the output of the network. Will have 1.0 in positions + in `tgt_symbols` that were masked (including blank + padding at the end of masked regions), `unmasked_weight` + in other positions in the original sentences (including + terminating eos_symbol); and 0.0 in the remaining positions + corresponding to blank padding after the ends of + sentences. + """ + assert blank_sym not in [bos_sym, eos_sym] + max_sent_len = max([ len(s) for s in sentences]) + + typical_mask_and_pad = int(max_sent_len * (mask_proportion + padding_proportion)) + + # The following formula gives roughly 1 standard deviation above where we'd + # expect the maximum sentence length to be with masking and padding.. we use + # this as a hard upper limit, to prevent outliers from affecting the batch + # size too much. We use this as the size `seq_len`. + # The "+ 4" is to ensure there is always room for the BOS, EOS and at least + # two padding symbols. + seq_len = max_sent_len + 4 + typical_mask_and_pad + int(typical_mask_and_pad ** 0.5) + + + # srcs, srcs_masked, tgts and weights will be lists of the lists returned + # from `mask_and_pad`, one per sentence. + srcs = [] + srcs_masked = [] + tgts = [] + weights = [] + randomizables = [] + attn_masks = [] + for s in sentences: + (src, src_masked, tgt, + weight, randomizable, + attn_mask) = mask_and_pad(s, seq_len, bos_sym, eos_sym, + blank_sym, mask_proportion, padding_proportion, + inv_mask_length, unmasked_weight) + srcs.append(src) + srcs_masked.append(src_masked) + tgts.append(tgt) + weights.append(weight) + randomizables.append(randomizable) + attn_masks.append(attn_mask) + + src_symbols = torch.tensor(srcs, dtype=torch.int64) + masked_src_symbols = torch.tensor(srcs_masked, dtype=torch.int64) + tgt_symbols = torch.tensor(tgts, dtype=torch.int64) + src_attn_mask = torch.tensor(attn_masks, dtype=torch.bool) + tgt_weights = torch.tensor(weights, dtype=torch.float) + + attn_mask_sum = torch.sum(torch.logical_not(src_attn_mask), dim=0).tolist() + while attn_mask_sum[-1] == 0: # Remove always-masked positions at the endof the lists. + attn_mask_sum.pop() + if len(attn_mask_sum) < seq_len: + seq_len = len(attn_mask_sum) + (src_symbols, masked_src_symbols, + tgt_symbols, src_attn_mask, tgt_weights) = (src_symbols[:,:seq_len], masked_src_symbols[:,:seq_len], + tgt_symbols[:,:seq_len], src_attn_mask[:,:seq_len], + tgt_weights[:,:seq_len]) + + if randomize_proportion > 0.0: + randomizable_tensor = torch.tensor(randomizables, dtype=torch.bool) + randomizable_indexes = torch.nonzero(randomizable_tensor) # (num_randomizable, 2) + num_randomizable = randomizable_indexes.shape[0] + + to_randomize_indexes = torch.nonzero(torch.rand(num_randomizable) < randomize_proportion, as_tuple=True)[0] + num_to_randomize = to_randomize_indexes.numel() + + # older versions of torch don't have tensor_split, so fake a simplified version of it. + # we'd be calling it as xxx.tensor_split(dim=1) if really in torc. 
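+        # (Note: on newer PyTorch versions this can probably be written directly,
+        # e.g. as torch.unbind(t, dim=1) or tuple(t.T) for a 2-column index tensor;
+        # the small helper below just emulates that for older versions.)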
+ def tensor_split(t): + return (t[:,0], t[:,1]) + + random_src_locations = torch.randperm(num_randomizable)[:num_to_randomize] + + random_symbols = src_symbols[tensor_split(randomizable_indexes[random_src_locations])] + random_indexes_tuple= tensor_split(randomizable_indexes[to_randomize_indexes]) + src_symbols[random_indexes_tuple] = random_symbols + masked_src_symbols[random_indexes_tuple] = random_symbols + + + # I set this to true and tested with: + # python3 -c 'import dataset; dataset.collate_fn(sentences=[ list(range(100, 200)), list(range(300, 450)), list(range(500,600))], bos_sym=1, eos_sym=2, blank_sym=0, mask_proportion=0.2, padding_proportion=0.2, randomize_proportion=0.05, inv_mask_length=0.33, unmasked_weight=0.444)' + #.. and ran a few times to check the values printed looked about right, and that no assertions failed. + if debug: + check_collated_tensors(sentences, bos_sym, eos_sym, blank_sym, + unmasked_weight, + masked_src_symbols, src_symbols, + tgt_symbols, src_attn_mask, tgt_weights) + return (masked_src_symbols, src_symbols, + tgt_symbols, src_attn_mask, tgt_weights) + + + +def check_collated_tensors(sentences: List[List[int]], + bos_sym: int, + eos_sym: int, + blank_sym: int, + unmasked_weight: float, + masked_src_symbols, src_symbols, + tgt_symbols, src_attn_mask, + tgt_weights): + """ + This function checks the output of collate_fn, consider it test code. Please see + the documentation of collate_fn to understand the args. + """ + for t in src_symbols, tgt_symbols, src_attn_mask, tgt_weights: + assert t.shape == masked_src_symbols.shape + + tot_positions = src_symbols.numel() + + masked_src_symbols, src_symbols, tgt_symbols, src_attn_mask, tgt_weights = ( + masked_src_symbols.tolist(), src_symbols.tolist(), tgt_symbols.tolist(), + src_attn_mask.tolist(), tgt_weights.tolist()) + assert len(sentences) == len(masked_src_symbols) + + tot_masked_positions = 0 + tot_padded_positions = 0 + tot_unmasked_positions = 0 # all un-masked, non-blank postions, including eos + tot_randomized_positions = 0 + num_masked_subseqs = 0 + tot_symbols = 0 # original symbols in sentences, no bos/eos + + assert unmasked_weight > 0.001 # or this test code won't work.. + + for i in range(len(sentences)): + reconstructed_sent = list(filter(lambda x: x not in [bos_sym,eos_sym,blank_sym], tgt_symbols[i])) + if sentences[i] != reconstructed_sent: + print(f"Error: sentence {i}={sentences[i]} differs from {reconstructed_sent}") + (masked_src, src, tgt, src_mask, weights) = (masked_src_symbols[i], src_symbols[i], + tgt_symbols[i], src_attn_mask[i], tgt_weights[i]) + + assert src[0] == masked_src[0] == bos_sym + for j in range(len(masked_src)): + assert masked_src[j] == blank_sym or masked_src[j] == src[j] + + if src[j] not in [bos_sym, eos_sym, blank_sym]: + tot_symbols += 1 + + if j > 0: + assert (src[j] == eos_sym) == (masked_src[j] == eos_sym) == (tgt[j-1] == eos_sym) + if masked_src[j] == blank_sym: # masked or padding of masked subseq, or post-eos padding.. + assert src[j] == tgt[j - 1] # masked symbols are not randomized. + assert weights[j - 1] in [0.0, 1.0] # 0.0 for final blank padding + if weights[j - 1] == 1.0: # Not final blank padding... 
+ if tgt[j - 1] == blank_sym: + tot_padded_positions += 1 + else: + tot_masked_positions += 1 + if masked_src[j + 1] != blank_sym: + num_masked_subseqs += 1 + else: + assert weights[j - 1] == 0 or abs(weights[j-1] - unmasked_weight) < 0.001 + if abs(weights[j - 1]-unmasked_weight) < 0.001: + tot_unmasked_positions += 1 + if tgt[j - 1] != src[j]: + tot_randomized_positions += 1 + + if src_mask[j]: # if masked.. + assert src[j] == blank_sym + + assert tot_symbols == sum(len(x) for x in sentences) + + assert tot_unmasked_positions + tot_masked_positions == tot_symbols + len(sentences) + + print(f"{tot_unmasked_positions} + {tot_masked_positions} == {tot_symbols} + {len(sentences)}") + print(f"tot_symbols / tot_positions = {tot_symbols/tot_positions} (rest is bos,eos,padding)") + + print(f"Masking/tot_symbols = {tot_masked_positions/tot_symbols}, Padding/tot_symbols = {tot_padded_positions/tot_symbols}") + print(f"Randomization/tot_non_masked_symbols = {tot_randomized_positions/(tot_symbols-tot_masked_positions)}") + print(f"Mean masking length = {tot_masked_positions/num_masked_subseqs}, Mean padding length = {tot_padded_positions/num_masked_subseqs}") + + + +# This shows some useful code about the BPE encoding. +# import sentencepiece as spm +# sp = spm.SentencePieceProcessor() +# sp.load(bpe_model_fn) # bpe.model +# sp.GetPieceSize(..) +# sp.Decode(...) +# sp.Encode(...) + + +# import dataset +# import torch +# train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt') + + +# train_dl = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True, collate_fn=(lambda x: train.collate_fn(x))) +# x = iter(train_dl) +# str(next(x)) +# '[ [ 10 38 651 593 3 1343 31 780 6 4172 112 788 1696 24 289 24 3 403 6 4493 162 92 71 328 417 217 338 14 5 3 1876 154 21 23 2237 43 3 1535 92 71 2816 7 1031 31 2318 92 2528 4806 14 206 3 954 1373 6 525 4 631 447 2639 ] [ 1014 336 171 209 795 10 16 90 27 787 139 53 45 2817 ] [ 11 980 51 22 1748 14 91 105 363 428 6 8 2887 3305 2525 2297 70 3 4651 6 27 282 335 426 134 292 5 193 3 539 2250 584 127 ] [ 9 3 1858 4 18 2257 4 6 41 748 10 304 7 229 83 2793 4 9 981 7 1484 33 3 103 7 539 5 477 3195 18 64 39 82 1034 6 3 4128 ] [ 17 147 22 7 708 60 133 174 105 4111 4 6 3 1384 65 50 1051 9 2953 6 3 461 180 1142 23 5 36 888 8 131 173 390 78 23 266 2822 715 46 182 65 22 1739 33 3 700 1450 14 233 4 ] [ 80 10 16 67 279 7 1827 264 96 3 187 2851 2108 ] [ 1473 48 106 227 9 160 2011 4 674 ] [ 3 954 762 29 85 228 33 8 940 40 4952 36 486 390 595 3 81 225 6 1440 125 346 134 296 126 419 1017 3824 4 8 179 184 11 33 580 1861 ] [ 30 22 245 15 117 8 2892 28 1204 145 7 3 236 3417 6 3 3839 5 3106 155 198 30 228 2555 46 15 32 41 747 72 9 25 977 ] [ 222 466 6 3157 ] ]' +# +# or: +# import k2 +# k2.ragged.to_list(next(x)) +# [shows something similar]. +# +# You'd really do something like: +# for epoch in range(max_epochs): +# for minibatch in train_dl: + + +# .. How to process data? Suppose we have a sentence like [259, 278, 45, 11, 303, 1319, 34, 15, 396, 3435, 7, 44]. +# +# First: we randomly choose one or more starting positins for a masked segment. +# Each sentence must have at least one masked segment (or there is no contribution to the loss function). 
+# We choose to have:
+#   num_masked_segments = max(1, len(sent) // 15)
+#
+# The length of each masked segment (this is the target for prediction) is drawn from a geometric
+# distribution with the probability of success set to 0.3:
+#
+#  g = torch.distributions.geometric.Geometric(probs=0.3)   # <-- expected value is 3.333
+#  Example of sampling:
+#  g.sample(sample_shape=torch.Size([10]))
+#
+# We now randomly compute the location of the masked segments (lengths computed above) as follows:
+# First, the masked segments must be separated by at least one non-masked word (else they would be
+# a single segment).  So for n masked segments, there are n-1 words required for minimal separation.
+# If tot-length-of-segments + n-1 is greater than the sentence length, we just have the entire
+# sentence be masked.  Otherwise, we randomly divide the remaining number of words between the n+1
+# positions where they can appear (e.g. for 2 segments, this would be at the start, between the 2 segments,
+# and at the end).  This is the multinomial distribution, but we can more easily compute this
+# directly using rand() and cutoffs, rather than creating a torch.distributions.Multinomial().
+#
+
+# Next we need to compute a random amount of blank padding (>= 0) for each of the masked regions;
+# this is done so the model never knows the exact length of the masked region.  We can just use the
+# same distribution as for the length of the masked regions, i.e. geometric with success-prob=0.3
+# (expected padding length is 3).
+#
+# At this point we know where the masked regions are and how much padding they have.  We can format
+# the result as three lists, of the same length:
+#
+#  sent: contains the words in the sentence with, in masked
+#        positions, the original (target) words, then with
+#        blank in the blank-padding after masked positions.
+#
+#  sent_augmented: `sent` with, at a small defined percentage of positions
+#        that were *not* masked, the real token replaced with a
+#        token randomly chosen from the tokens in the minibatch.
+#        (Like BERT, we use this type of augmentation, so the model
+#        has to predict the original token.)
+#
+#  masked_sent_augmented: List[int], contains the words in `sent_augmented`, except
+#        with masked positions and the blank padding after the masked regions
+#        both replaced with blank.
+#
+#
+# The way these will be processed is as follows:
+#
+#  masked_sent_in = [bos] + masked_sent_augmented + [eos]  <-- so we know the sentence ended, to distinguish it from truncated ones.
+#  sent_in = [bos] + sent_augmented + [eos]
+#
+#  sent_out = sent + [eos] + [eos]   # <--- the predicted targets at each point, although
+#                                    # we only really care about this in masked regions.
+#                                    # The extra eos is so that the length is the same as
+#                                    # masked_sent_in and sent_in.
+#
+#  out_scale = (masked_sent==blk ? 1.0 : non_masked_scale)  # e.g. non_masked_scale = 1.0 is fine;
+#                                    # this is a choice; we can perhaps
+#                                    # report these 2 parts of the loss
+#                                    # separately though.
+#                                    # <-- can also set the last element
+#                                    # of out_scale to a smaller number, since
+#                                    # it's a repeated eos.
+#
+#
+# OK, how do we combine these into a minibatch?  Firstly, we truncate sentences to a maximum
+# length, e.g. 128, if `masked_sent_in`/`sent_in` have length longer than that.  We choose randomly
+# in each case to truncate the beginning or end, truncating both masked_sent_in/sent_in and sent_out
+# from the same side.  Caution: this means that these sentences may lack bos and/or eos symbols.
+# +# Next, we combine shorter utterances by appending them ( all of: masked_sent_in, sent_in, out_scale) +# as long as doing so would keep the total length under 128. We then pad (masked_sent_in, sent_in, sent_out, out_scale) +# with: (,,, 0) up to the maximum length of any sentence in the minibatch <- or could use +# +# +# +# +# +# +# +# # i.e. ones where masked_sent is blank and zeros elsewhere; +# # this pertains to positions in `sent_out`. +# +# +# +# +# +# +# +# +# +# +# torch.distributions.gamma.Gamma(concentration=1.0, rate=1.0/5) + + + + +class LmBatchSampler(torch.utils.data.Sampler): + """ + A sampler that returns a batch of integer indexes as a list, intended for use + with class LmDataset. The sentences returned in each batch will all be about + the same size, and the batch size is specified as a number of words (we also + provide an option that allows you to limit the max memory consumed by transformers) + + Has support for distributed operation. + """ + def __init__(self, dataset: LmDataset, + symbols_per_batch: int, + quadratic_constant: float = 0.005, + world_size: Optional[int] = None, + rank: int = None, + seed: int = 0): + """ + Constructor documentation: + dataset: the LmDataset object that we are sampling from. This + class does not retain a reference to the LmDataset. + symbols_per_batch: The number of BPE symbols desired in each minibatch + quadratic_constant: After the sentence length gets more than about + 1.0/quadratic_constant, the batch size will start decreasing + as 1/(sentence-length^2). This is a mechanism to + avoid excessive memory consumption in transformers, when + sentence length gets long. + world_size: The world size for distributed operation; if None, + will be worked out from torch.distributed. + rank: The rank of this sampler/process for distributed operation; if None, + will be worked out from torch.distributed. + seed: The random seed + """ + self.seed = seed + self.symbols_per_batch = symbols_per_batch + self.quadratic_constant = quadratic_constant + self._maybe_init_distributed(world_size=world_size, rank=rank) + + # a configuration constant we don't expose. + self.multiplicative_random_length = 0.05 + + # "indexes" is the subset of indexes into LmDataset that this + # sampler is reponsible for (all of them, in the non-distributed case). + data_indexes = torch.arange(self.rank, len(dataset), self.world_size, dtype=torch.int32) # dtype=torch.int32 + + word_row_splits = dataset.words.row_splits(1) # dtype=torch.int32 + word_lengths = word_row_splits[1:] - word_row_splits[:-1] # dtype=torch.int32 + + # the sentences this sampler is responsible for, as sequences of words. + # It's a ragged tensor of int32 + sentences = k2.index(dataset.sentences, data_indexes) + + # sentence_lengths is a k2.RaggedInt like `sentences`, but with the words replaced + # with their respective lengths, in BPE pieces. + sentence_lengths = k2.index(word_lengths, sentences) + del sentences # save memory + assert isinstance(sentence_lengths, k2.RaggedInt) + + # convert to float so sum_per_sublist() will work (TODO: sum_per_sublist() will eventually + # support int32.) + sentence_lengths = k2.RaggedFloat(sentence_lengths.shape(), + sentence_lengths.values().to(torch.float32)) + assert isinstance(sentence_lengths, k2.RaggedFloat) + + # Convert into a simple tensor of float by adding lengths of words. 
+ sentence_lengths = k2.ragged.sum_per_sublist(sentence_lengths) + + assert isinstance(sentence_lengths, torch.Tensor) + assert sentence_lengths.dtype == torch.float32 + + # self.sentence_lengths is a Tensor with dtype=torch.float32. It + # contains the lengths, in BPE tokens, of the sentences that this + # sampler is responsible for, whose real indexes are in + # `data_indexes` above (this is not stored, as we know the formula). + self.sentence_lengths = sentence_lengths + + self.set_epoch(0) # this is responsible for setting self.sorted_data_indexes + + + def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]): + if world_size is not None: + assert world_size >= 1 + if rank is not None: + assert rank >= 0 + if not dist.is_available() or not dist.is_initialized(): + self.world_size = 1 if world_size is None else world_size + self.rank = 0 if rank is None else rank + return + self.world_size = dist.get_world_size() if world_size is None else world_size + self.rank = dist.get_rank() if rank is None else rank + assert self.rank < self.world_size + + def set_epoch(self, epoch: int): + """ + Must be called at the beginning of each epoch, before initializing the DataLoader, + to re-shuffle the data. If this is not done, this sampler will give you the same batches + each time it is called. + """ + g = torch.manual_seed(self.rank + self.seed + epoch) + + sentence_lengths = (self.sentence_lengths * + (1.0 + torch.rand(*self.sentence_lengths.shape, generator=g) * self.multiplicative_random_length)) + + # This mechanism regulates the batch size so that we don't get OOM in transformers + # when the sentences are long. + sentence_lengths = sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant + + values, indices = torch.sort(sentence_lengths) # values,indices dtypes: torch.float,torch.int64 + + # map to the original indexes into the dataset (the original sentence + # indexes), see torch.arange expression in the constructor. save as + # int32 just to save a little memory. self.indices are indexes into the + # LmDataset, just including the subset of indices that this sampler is + # responsible for (in terms of rank and world_size), and sorted by + # length with a small amount of randomization specific to the epoch. + self.indices = ((indices * self.world_size) + self.rank).to(dtype=torch.int32) + + # now `batch_ids` will be: [0, 0, 0, 0, .., 0, 1, 1, 1, ... 1, 2, ... ], + # saying which batch each element of values/indices belongs to. + batch_ids = (torch.cumsum(values, dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32) + + batch_boundaries = torch.nonzero(batch_ids[1:] - batch_ids[:-1], as_tuple=True)[0] + batch_boundaries.add_(1) + self.batch_boundaries = torch.cat((torch.zeros(1, dtype=torch.int32), batch_boundaries), dim=0) + + num_batches = self.batch_boundaries.numel() - 1 + + # self.batch_indices is a permutation of [0, 1, ... num_batches - + # 1]; it determines the order in which we access the batches. It's + # necessary to randomize the order of these, to avoid returning batches + # from shortest to longest sentences. 
+ self.batch_indices = torch.randperm(num_batches, generator=g, dtype=torch.int32).tolist() + + + def __len__(self): + return len(self.batch_indices) + + def __iter__(self): + """ + Iterator that yields lists of indices (i.e., integer indices into the LmDataset) + """ + for batch_idx in self.batch_indices: + batch_start = self.batch_boundaries[batch_idx].item() + batch_end = self.batch_boundaries[batch_idx + 1].item() + yield self.indices[batch_start:batch_end].tolist() + + + + + + +# train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt') +# sampler = dataset.LmBatchSampler(test, symbols_per_batch=1000, world_size=2, rank=0) +# a = iter(sampler) +# print(str(next(a))) + +# collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=1, eos_sym=1, blank_sym=0, debug=True)) +# train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler, collate_fn=collate_fn) +# x = iter(train_dl) +# print(str(next(x))) diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 5c9e2a6751..e78c6d9f3e 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -46,6 +46,7 @@ def main(out_dir: str): "4-gram.arpa.gz", "librispeech-vocab.txt", "librispeech-lexicon.txt", + "librispeech-lm-norm.txt.gz" ) for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"): diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index f06e013f60..798a306312 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -191,4 +191,9 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then done fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + +fi + cd data && ln -sfv lang_bpe_5000 lang_bpe From cbe5ee1111326ab8a9f9c0359765c4c8c650c7be Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 21 Aug 2021 22:35:43 +0800 Subject: [PATCH 02/26] Copy some files, will edit.. 
--- egs/librispeech/ASR/conformer_lm/conformer.py | 920 ++++++++++++++++ .../ASR/conformer_lm/test_dataset.py | 13 + .../ASR/conformer_lm/test_transformer.py | 89 ++ .../ASR/conformer_lm/transformer.py | 989 ++++++++++++++++++ 4 files changed, 2011 insertions(+) create mode 100644 egs/librispeech/ASR/conformer_lm/conformer.py create mode 100644 egs/librispeech/ASR/conformer_lm/test_dataset.py create mode 100644 egs/librispeech/ASR/conformer_lm/test_transformer.py create mode 100644 egs/librispeech/ASR/conformer_lm/transformer.py diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py new file mode 100644 index 0000000000..a00664a992 --- /dev/null +++ b/egs/librispeech/ASR/conformer_lm/conformer.py @@ -0,0 +1,920 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# Apache 2.0 + +import math +import warnings +from typing import Optional, Tuple + +import torch +from torch import Tensor, nn +from transformer import Supervisions, Transformer, encoder_padding_mask + + +class Conformer(Transformer): + """ + Args: + num_features (int): Number of input features + num_classes (int): Number of output classes + subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) + d_model (int): attention dimension + nhead (int): number of head + dim_feedforward (int): feedforward dimention + num_encoder_layers (int): number of encoder layers + num_decoder_layers (int): number of decoder layers + dropout (float): dropout rate + cnn_module_kernel (int): Kernel size of convolution module + normalize_before (bool): whether to use layer_norm before the first block. + vgg_frontend (bool): whether to use vgg frontend. + """ + + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + vgg_frontend: bool = False, + is_espnet_structure: bool = False, + mmi_loss: bool = True, + use_feat_batchnorm: bool = False, + ) -> None: + super(Conformer, self).__init__( + num_features=num_features, + num_classes=num_classes, + subsampling_factor=subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dropout=dropout, + normalize_before=normalize_before, + vgg_frontend=vgg_frontend, + mmi_loss=mmi_loss, + use_feat_batchnorm=use_feat_batchnorm, + ) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + cnn_module_kernel, + normalize_before, + is_espnet_structure, + ) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self.normalize_before = normalize_before + self.is_espnet_structure = is_espnet_structure + if self.normalize_before and self.is_espnet_structure: + self.after_norm = nn.LayerNorm(d_model) + else: + # Note: TorchScript detects that self.after_norm could be used inside forward() + # and throws an error without this change. + self.after_norm = identity + + def run_encoder( + self, x: Tensor, supervisions: Optional[Supervisions] = None + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x: + The model input. Its shape is [N, T, C]. + supervisions: + Supervision in lhotse format. 
+ See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. + + Returns: + Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). + Tensor: Mask tensor of dimension (batch_size, input_length) + """ + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + mask = encoder_padding_mask(x.size(0), supervisions) + if mask is not None: + mask = mask.to(x.device) + x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) + + if self.normalize_before and self.is_espnet_structure: + x = self.after_norm(x) + + return x, mask + + +class ConformerEncoderLayer(nn.Module): + """ + ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + See: "Conformer: Convolution-augmented Transformer for Speech Recognition" + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + normalize_before: whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + is_espnet_structure: bool = False, + ) -> None: + super(ConformerEncoderLayer, self).__init__() + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure + ) + + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.feed_forward_macaron = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module + self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + self.norm_mha = nn.LayerNorm(d_model) # for the MHA module + + self.ff_scale = 0.5 + + self.norm_conv = nn.LayerNorm(d_model) # for the CNN module + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block + + self.dropout = nn.Dropout(dropout) + + self.normalize_before = normalize_before + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). 
+ S is the source sequence length, N is the batch size, E is the feature number + """ + + # macaron style feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff_macaron(src) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) + if not self.normalize_before: + src = self.norm_ff_macaron(src) + + # multi-headed self-attention module + residual = src + if self.normalize_before: + src = self.norm_mha(src) + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout(src_att) + if not self.normalize_before: + src = self.norm_mha(src) + + # convolution module + residual = src + if self.normalize_before: + src = self.norm_conv(src) + src = residual + self.dropout(self.conv_module(src)) + if not self.normalize_before: + src = self.norm_conv(src) + + # feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff(src) + src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) + if not self.normalize_before: + src = self.norm_ff(src) + + if self.normalize_before: + src = self.norm_final(src) + + return src + + +class ConformerEncoder(nn.TransformerEncoder): + r"""ConformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ConformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = conformer_encoder(src, pos_emb) + """ + + def __init__( + self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None + ) -> None: + super(ConformerEncoder, self).__init__( + encoder_layer=encoder_layer, num_layers=num_layers, norm=norm + ) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + output = src + + for mod in self.layers: + output = mod( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. 
+ + """ + + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: Tensor) -> None: + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x.size(1) + + 1 : self.pe.size(1) // 2 + + x.size(1), + ] + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_espnet_structure: bool = False, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + # linear transformation for positional encoding. 
+ self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + + self._reset_parameters() + + self.is_espnet_structure = is_espnet_structure + + def _reset_parameters(self) -> None: + nn.init.xavier_uniform_(self.in_proj.weight) + nn.init.constant_(self.in_proj.bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. 
+ """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.weight, + self.in_proj.bias, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + ) + + def rel_shift(self, x: Tensor) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + assert n == 2 * time1 - 1 + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time1), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + if not self.is_espnet_structure: + q = q * scaling + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError( + "The size of the 2D attn_mask is not correct." 
+ ) + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + q_with_bias_u = (q + self.pos_bias_u).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self.pos_bias_v).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul( + q_with_bias_v, p.transpose(-2, -1) + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd) + + if not self.is_espnet_structure: + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) + else: + attn_output_weights = ( + matrix_ac + matrix_bd + ) * scaling # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) + + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + 
bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + self.norm = nn.BatchNorm1d(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward(self, x: Tensor) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def identity(x): + return x diff --git a/egs/librispeech/ASR/conformer_lm/test_dataset.py b/egs/librispeech/ASR/conformer_lm/test_dataset.py new file mode 100644 index 0000000000..ed38ed11a2 --- /dev/null +++ b/egs/librispeech/ASR/conformer_lm/test_dataset.py @@ -0,0 +1,13 @@ +import dataset +import torch + + +train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt') +sampler = dataset.LmBatchSampler(test, symbols_per_batch=1000, world_size=2, rank=0) +a = iter(sampler) +print(str(next(a))) + +collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=1, eos_sym=1, blank_sym=0, debug=True)) +train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler, collate_fn=collate_fn) +x = iter(train_dl) +print(str(next(x))) diff --git a/egs/librispeech/ASR/conformer_lm/test_transformer.py b/egs/librispeech/ASR/conformer_lm/test_transformer.py new file mode 100644 index 0000000000..08e6806074 --- /dev/null +++ b/egs/librispeech/ASR/conformer_lm/test_transformer.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 + +import torch +from transformer import ( + Transformer, + encoder_padding_mask, + generate_square_subsequent_mask, + decoder_padding_mask, + add_sos, + add_eos, +) + +from torch.nn.utils.rnn import pad_sequence + + +def test_encoder_padding_mask(): + supervisions = { + "sequence_idx": torch.tensor([0, 1, 2]), + "start_frame": torch.tensor([0, 0, 0]), + "num_frames": torch.tensor([18, 7, 13]), + } + + max_len = ((18 - 1) 
// 2 - 1) // 2 + mask = encoder_padding_mask(max_len, supervisions) + expected_mask = torch.tensor( + [ + [False, False, False], # ((18 - 1)//2 - 1)//2 = 3, + [False, True, True], # ((7 - 1)//2 - 1)//2 = 1, + [False, False, True], # ((13 - 1)//2 - 1)//2 = 2, + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_transformer(): + num_features = 40 + num_classes = 87 + model = Transformer(num_features=num_features, num_classes=num_classes) + + N = 31 + + for T in range(7, 30): + x = torch.rand(N, T, num_features) + y, _, _ = model(x) + assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes) + + +def test_generate_square_subsequent_mask(): + s = 5 + mask = generate_square_subsequent_mask(s) + inf = float("inf") + expected_mask = torch.tensor( + [ + [0.0, -inf, -inf, -inf, -inf], + [0.0, 0.0, -inf, -inf, -inf], + [0.0, 0.0, 0.0, -inf, -inf], + [0.0, 0.0, 0.0, 0.0, -inf], + [0.0, 0.0, 0.0, 0.0, 0.0], + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_decoder_padding_mask(): + x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])] + y = pad_sequence(x, batch_first=True, padding_value=-1) + mask = decoder_padding_mask(y, ignore_id=-1) + expected_mask = torch.tensor( + [ + [False, False, True], + [False, True, True], + [False, False, False], + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_add_sos(): + x = [[1, 2], [3], [2, 5, 8]] + y = add_sos(x, sos_id=0) + expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]] + assert y == expected_y + + +def test_add_eos(): + x = [[1, 2], [3], [2, 5, 8]] + y = add_eos(x, eos_id=0) + expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]] + assert y == expected_y diff --git a/egs/librispeech/ASR/conformer_lm/transformer.py b/egs/librispeech/ASR/conformer_lm/transformer.py new file mode 100644 index 0000000000..51c77b2209 --- /dev/null +++ b/egs/librispeech/ASR/conformer_lm/transformer.py @@ -0,0 +1,989 @@ +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# Apache 2.0 + +import math +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from subsampling import Conv2dSubsampling, VggSubsampling +from torch.nn.utils.rnn import pad_sequence + +# Note: TorchScript requires Dict/List/etc. to be fully typed. +Supervisions = Dict[str, torch.Tensor] + + +class Transformer(nn.Module): + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + normalize_before: bool = True, + vgg_frontend: bool = False, + mmi_loss: bool = True, + use_feat_batchnorm: bool = False, + ) -> None: + """ + Args: + num_features: + The input dimension of the model. + num_classes: + The output dimension of the model. + subsampling_factor: + Number of output frames is num_in_frames // subsampling_factor. + Currently, subsampling_factor MUST be 4. + d_model: + Attention dimension. + nhead: + Number of heads in multi-head attention. + Must satisfy d_model // nhead == 0. + dim_feedforward: + The output dimension of the feedforward layers in encoder/decoder. + num_encoder_layers: + Number of encoder layers. + num_decoder_layers: + Number of decoder layers. + dropout: + Dropout in encoder/decoder. + normalize_before: + If True, use pre-layer norm; False to use post-layer norm. + vgg_frontend: + True to use vgg style frontend for subsampling. 
+ mmi_loss: + use_feat_batchnorm: + True to use batchnorm for the input layer. + """ + super().__init__() + self.use_feat_batchnorm = use_feat_batchnorm + if use_feat_batchnorm: + self.feat_batchnorm = nn.BatchNorm1d(num_features) + + self.num_features = num_features + self.num_classes = num_classes + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape [N, T, num_classes] + # to the shape [N, T//subsampling_factor, d_model]. + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_classes -> d_model + if vgg_frontend: + self.encoder_embed = VggSubsampling(num_features, d_model) + else: + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder_pos = PositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + encoder_norm = nn.LayerNorm(d_model) + else: + encoder_norm = None + + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + norm=encoder_norm, + ) + + # TODO(fangjun): remove dropout + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) + ) + + if num_decoder_layers > 0: + if mmi_loss: + self.decoder_num_class = ( + self.num_classes + 1 + ) # +1 for the sos/eos symbol + else: + self.decoder_num_class = ( + self.num_classes + ) # bpe model already has sos/eos symbol + + self.decoder_embed = nn.Embedding( + num_embeddings=self.decoder_num_class, embedding_dim=d_model + ) + self.decoder_pos = PositionalEncoding(d_model, dropout) + + decoder_layer = TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + decoder_norm = nn.LayerNorm(d_model) + else: + decoder_norm = None + + self.decoder = nn.TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=num_decoder_layers, + norm=decoder_norm, + ) + + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) + + self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) + else: + self.decoder_criterion = None + + def forward( + self, x: torch.Tensor, supervision: Optional[Supervisions] = None + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is [N, T, C]. + supervision: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + + Returns: + Return a tuple containing 3 tensors: + - CTC output for ctc decoding. Its shape is [N, T, C] + - Encoder output with shape [T, N, C]. It can be used as key and + value for the decoder. + - Encoder output padding mask. It can be used as + memory_key_padding_mask for the decoder. Its shape is [N, T]. + It is None if `supervision` is None. 
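+
+        A rough, shape-only sketch of a call (the feature and class sizes
+        here are arbitrary; an input of 100 frames comes out as
+        ((100 - 1) // 2 - 1) // 2 = 24 frames after subsampling)::
+
+            >>> model = Transformer(num_features=40, num_classes=500)
+            >>> feats = torch.rand(8, 100, 40)        # [N, T, C]
+            >>> ctc_out, memory, mask = model(feats)  # no `supervision` given
+            >>> # ctc_out: [8, 24, 500], memory: [24, 8, 256], mask: None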
+ """ + if self.use_feat_batchnorm: + x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T] + x = self.feat_batchnorm(x) + x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C] + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) + x = self.ctc_output(encoder_memory) + return x, encoder_memory, memory_key_padding_mask + + def run_encoder( + self, x: torch.Tensor, supervisions: Optional[Supervisions] = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Run the transformer encoder. + + Args: + x: + The model input. Its shape is [N, T, C]. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute the encoder padding mask, which is used as memory key + padding mask for the decoder. + Returns: + Return a tuple with two tensors: + - The encoder output, with shape [T, N, C] + - encoder padding mask, with shape [N, T]. + The mask is None if `supervisions` is None. + It is used as memory key padding mask in the decoder. + """ + x = self.encoder_embed(x) + x = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + mask = encoder_padding_mask(x.size(0), supervisions) + mask = mask.to(x.device) if mask is not None else None + x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) + + return x, mask + + def ctc_output(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + The output tensor from the transformer encoder. + Its shape is [T, N, C] + + Returns: + Return a tensor that can be used for CTC decoding. + Its shape is [N, T, C] + """ + x = self.encoder_output_layer(x) + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) + return x + + def decoder_forward( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[List[int]], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape [T, N, C] + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs. Each sublist contains IDs for an utterance. + The IDs can be either phone IDs or word piece IDs. + sos_id: + sos token id + eos_id: + eos token id + + Returns: + A scalar, the **sum** of label smoothing loss over utterances + in the batch without any normalization. + """ + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) + + device = memory.device + ys_in_pad = ys_in_pad.to(device) + ys_out_pad = ys_out_pad.to(device) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. 
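+        # E.g. with sos_id == eos_id == 1, a row [1, 5, 8, 1, 1] of ys_in_pad
+        # yields the mask row [True, False, False, True, True]; the leading True
+        # must be cleared because position 0 holds a real sos, not padding.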
+ tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, N, C) + pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) + + decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) + + return decoder_loss + + def decoder_nll( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[List[int]], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape [T, N, C] + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs (e.g., word piece IDs). + Each sublist represents an utterance. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + Returns: + A 2-D tensor of shape (len(token_ids), max_token_length) + representing the cross entropy loss (i.e., negative log-likelihood). + """ + # The common part between this function and decoder_forward could be + # extracted as a separate function. + + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) + + device = memory.device + ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) + ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, B, F) + pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) + pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + pred_pad.view(-1, self.decoder_num_class), + ys_out_pad.view(-1), + ignore_index=-1, + reduction="none", + ) + + nll = nll.view(pred_pad.shape[0], -1) + + return nll + + +class TransformerEncoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerEncoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + normalize_before: + whether to use layer_norm before the first block. 
+ + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerEncoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerEncoderLayer, self).__setstate__(state) + + def forward( + self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional) + + Shape: + src: (S, N, E). + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + residual = src + if self.normalize_before: + src = self.norm1(src) + src2 = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout1(src2) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src2) + if not self.normalize_before: + src = self.norm2(src) + return src + + +class TransformerDecoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerDecoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). 
+ + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerDecoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerDecoderLayer, self).__setstate__(state) + + def forward( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: + the sequence to the decoder layer (required). + memory: + the sequence from the last layer of the encoder (required). + tgt_mask: + the mask for the tgt sequence (optional). + memory_mask: + the mask for the memory sequence (optional). + tgt_key_padding_mask: + the mask for the tgt keys per batch (optional). + memory_key_padding_mask: + the mask for the memory keys per batch (optional). + + Shape: + tgt: (T, N, E). + memory: (S, N, E). + tgt_mask: (T, T). + memory_mask: (T, S). + tgt_key_padding_mask: (N, T). + memory_key_padding_mask: (N, S). 
+ S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + tgt2 = self.self_attn( + tgt, + tgt, + tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask, + )[0] + tgt = residual + self.dropout1(tgt2) + if not self.normalize_before: + tgt = self.norm1(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm2(tgt) + tgt2 = self.src_attn( + tgt, + memory, + memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = residual + self.dropout2(tgt2) + if not self.normalize_before: + tgt = self.norm2(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = residual + self.dropout3(tgt2) + if not self.normalize_before: + tgt = self.norm3(tgt) + return tgt + + +def _get_activation_fn(activation: str): + if activation == "relu": + return nn.functional.relu + elif activation == "gelu": + return nn.functional.gelu + + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) + + +class PositionalEncoding(nn.Module): + """This class implements the positional encoding + proposed in the following paper: + + - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf + + PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) + PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) + + Note:: + + 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) + = exp(-1* 2i / d_model * log(100000)) + = exp(2i * -(log(10000) / d_model)) + """ + + def __init__(self, d_model: int, dropout: float = 0.1) -> None: + """ + Args: + d_model: + Embedding dimension. + dropout: + Dropout probability to be applied to the output of this module. + """ + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = nn.Dropout(p=dropout) + self.pe = None + + def extend_pe(self, x: torch.Tensor) -> None: + """Extend the time t in the positional encoding if required. + + The shape of `self.pe` is [1, T1, d_model]. The shape of the input x + is [N, T, d_model]. If T > T1, then we change the shape of self.pe + to [N, T, d_model]. Otherwise, nothing is done. + + Args: + x: + It is a tensor of shape [N, T, C]. + Returns: + Return None. + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + # Now pe is of shape [1, T, d_model], where T is x.size(1) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Add positional encoding. + + Args: + x: + Its shape is [N, T, C] + + Returns: + Return a tensor of shape [N, T, C] + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1), :] + return self.dropout(x) + + +class Noam(object): + """ + Implements Noam optimizer. 
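+
+    The schedule implemented by `rate()` below is the usual Noam formula,
+    reproduced here for reference (it matches the expression in `rate()`)::
+
+        lr = factor * model_size**(-0.5) * min(step**(-0.5), step * warm_step**(-1.5))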
+ + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) + + +class LabelSmoothingLoss(nn.Module): + """ + Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w) + and p_{prob. computed by model}(w) is minimized. + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa + + Args: + size: the number of class + padding_idx: padding_idx: ignored class id + smoothing: smoothing rate (0.0 means the conventional CE) + normalize_length: normalize loss by sequence length if True + criterion: loss function to be smoothed + """ + + def __init__( + self, + size: int, + padding_idx: int = -1, + smoothing: float = 0.1, + normalize_length: bool = False, + criterion: nn.Module = nn.KLDivLoss(reduction="none"), + ) -> None: + """Construct an LabelSmoothingLoss object.""" + super(LabelSmoothingLoss, self).__init__() + self.criterion = criterion + self.padding_idx = padding_idx + assert 0.0 < smoothing <= 1.0 + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.size = size + self.true_dist = None + self.normalize_length = normalize_length + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute loss between x and target. + + Args: + x: + prediction of dimension + (batch_size, input_length, number_of_classes). + target: + target masked with self.padding_id of + dimension (batch_size, input_length). + + Returns: + A scalar tensor containing the loss without normalization. 
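+
+        A minimal sketch of a call (shapes only; the inputs are random)::
+
+            >>> loss_fn = LabelSmoothingLoss(size=5, smoothing=0.1)
+            >>> x = torch.randn(2, 3, 5)      # (batch, length, classes)
+            >>> target = torch.tensor([[1, 2, -1], [0, 4, 3]])  # -1 = padding
+            >>> loss = loss_fn(x, target)     # un-normalized scalar sum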
+ """ + assert x.size(2) == self.size + # batch_size = x.size(0) + x = x.view(-1, self.size) + target = target.view(-1) + with torch.no_grad(): + true_dist = x.clone() + true_dist.fill_(self.smoothing / (self.size - 1)) + ignore = target == self.padding_idx # (B,) + total = len(target) - ignore.sum().item() + target = target.masked_fill(ignore, 0) # avoid -1 index + true_dist.scatter_(1, target.unsqueeze(1), self.confidence) + kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) + # denom = total if self.normalize_length else batch_size + denom = total if self.normalize_length else 1 + return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom + + +def encoder_padding_mask( + max_len: int, supervisions: Optional[Supervisions] = None +) -> Optional[torch.Tensor]: + """Make mask tensor containing indexes of padded part. + + TODO:: + This function **assumes** that the model uses + a subsampling factor of 4. We should remove that + assumption later. + + Args: + max_len: + Maximum length of input features. + CAUTION: It is the length after subsampling. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + + Returns: + Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices. + """ + if supervisions is None: + return None + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"], + supervisions["num_frames"], + ), + 1, + ).to(torch.int32) + + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] + for idx in range(supervision_segments.size(0)): + # Note: TorchScript doesn't allow to unpack tensors as tuples + sequence_idx = supervision_segments[idx, 0].item() + start_frame = supervision_segments[idx, 1].item() + num_frames = supervision_segments[idx, 2].item() + lengths[sequence_idx] = start_frame + num_frames + + lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] + bs = int(len(lengths)) + seq_range = torch.arange(0, max_len, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) + # Note: TorchScript doesn't implement Tensor.new() + seq_length_expand = torch.tensor( + lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype + ).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + return mask + + +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: + """Generate a length mask for input. + + The masked position are filled with True, + Unmasked positions are filled with False. + + Args: + ys_pad: + padded tensor of dimension (batch_size, input_length). + ignore_id: + the ignored number (the padding number) in ys_pad + + Returns: + Tensor: + a bool tensor of the same shape as the input tensor. + """ + ys_mask = ys_pad == ignore_id + return ys_mask + + +def generate_square_subsequent_mask(sz: int) -> torch.Tensor: + """Generate a square mask for the sequence. The masked positions are + filled with float('-inf'). Unmasked positions are filled with float(0.0). + The mask can be used for masked self-attention. 
+ + For instance, if sz is 3, it returns:: + + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0]]) + + Args: + sz: mask size + + Returns: + A square mask of dimension (sz, sz) + """ + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask + + +def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: + """Prepend sos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + sos_id: + The ID of the SOS token. + + Return: + Return a new list-of-list, where each sublist starts + with SOS ID. + """ + ans = [] + for utt in token_ids: + ans.append([sos_id] + utt) + return ans + + +def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: + """Append eos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + eos_id: + The ID of the EOS token. + + Return: + Return a new list-of-list, where each sublist ends + with EOS ID. + """ + ans = [] + for utt in token_ids: + ans.append(utt + [eos_id]) + return ans From 076a70b62dd4eb804b15b3dc87126932de12c071 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 22 Aug 2021 11:47:26 +0800 Subject: [PATCH 03/26] Initial conformer refactoring, not nearly done --- egs/librispeech/ASR/conformer_lm/conformer.py | 7 --- .../ASR/conformer_lm/transformer.py | 52 +++++++------------ 2 files changed, 20 insertions(+), 39 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py index a00664a992..3014055b42 100644 --- a/egs/librispeech/ASR/conformer_lm/conformer.py +++ b/egs/librispeech/ASR/conformer_lm/conformer.py @@ -26,7 +26,6 @@ class Conformer(Transformer): dropout (float): dropout rate cnn_module_kernel (int): Kernel size of convolution module normalize_before (bool): whether to use layer_norm before the first block. - vgg_frontend (bool): whether to use vgg frontend. """ def __init__( @@ -42,10 +41,7 @@ def __init__( dropout: float = 0.1, cnn_module_kernel: int = 31, normalize_before: bool = True, - vgg_frontend: bool = False, is_espnet_structure: bool = False, - mmi_loss: bool = True, - use_feat_batchnorm: bool = False, ) -> None: super(Conformer, self).__init__( num_features=num_features, @@ -58,9 +54,6 @@ def __init__( num_decoder_layers=num_decoder_layers, dropout=dropout, normalize_before=normalize_before, - vgg_frontend=vgg_frontend, - mmi_loss=mmi_loss, - use_feat_batchnorm=use_feat_batchnorm, ) self.encoder_pos = RelPositionalEncoding(d_model, dropout) diff --git a/egs/librispeech/ASR/conformer_lm/transformer.py b/egs/librispeech/ASR/conformer_lm/transformer.py index 51c77b2209..707eacd1b9 100644 --- a/egs/librispeech/ASR/conformer_lm/transformer.py +++ b/egs/librispeech/ASR/conformer_lm/transformer.py @@ -6,7 +6,6 @@ import torch import torch.nn as nn -from subsampling import Conv2dSubsampling, VggSubsampling from torch.nn.utils.rnn import pad_sequence # Note: TorchScript requires Dict/List/etc. to be fully typed. 
@@ -18,7 +17,6 @@ def __init__( self, num_features: int, num_classes: int, - subsampling_factor: int = 4, d_model: int = 256, nhead: int = 4, dim_feedforward: int = 2048, @@ -26,9 +24,6 @@ def __init__( num_decoder_layers: int = 6, dropout: float = 0.1, normalize_before: bool = True, - vgg_frontend: bool = False, - mmi_loss: bool = True, - use_feat_batchnorm: bool = False, ) -> None: """ Args: @@ -54,16 +49,9 @@ def __init__( Dropout in encoder/decoder. normalize_before: If True, use pre-layer norm; False to use post-layer norm. - vgg_frontend: - True to use vgg style frontend for subsampling. - mmi_loss: - use_feat_batchnorm: - True to use batchnorm for the input layer. """ super().__init__() - self.use_feat_batchnorm = use_feat_batchnorm - if use_feat_batchnorm: - self.feat_batchnorm = nn.BatchNorm1d(num_features) + self.num_features = num_features self.num_classes = num_classes @@ -76,10 +64,10 @@ def __init__( # That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_classes -> d_model - if vgg_frontend: - self.encoder_embed = VggSubsampling(num_features, d_model) - else: - self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + + #self.encoder_embed = [TODO...] + self.encoder_pos = PositionalEncoding(d_model, dropout) @@ -108,14 +96,7 @@ def __init__( ) if num_decoder_layers > 0: - if mmi_loss: - self.decoder_num_class = ( - self.num_classes + 1 - ) # +1 for the sos/eos symbol - else: - self.decoder_num_class = ( - self.num_classes - ) # bpe model already has sos/eos symbol + self.decoder_num_class = self.num_classes self.decoder_embed = nn.Embedding( num_embeddings=self.decoder_num_class, embedding_dim=d_model @@ -150,12 +131,22 @@ def __init__( self.decoder_criterion = None def forward( - self, x: torch.Tensor, supervision: Optional[Supervisions] = None + self, + src_symbols: torch.Tensor, + src_padding_mask: torch.Tensor = None ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Args: - x: - The input tensor. Its shape is [N, T, C]. + src_symbols: + The input symbols to be embedded (will actually have query positions + masked), as a Tensor of shape (batch_size, seq_len) and dtype=torch.int64. + I.e. shape (N, T) + src_padding_mask: + Either None, or a Tensor of shape (batch_size, seq_len) i.e. (N, T), + and dtype=torch.bool which has True in positions to be masked in attention + layers and convolutions because they represent padding at the ends of + sequences. + supervision: Supervision in lhotse format. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa @@ -171,10 +162,7 @@ def forward( memory_key_padding_mask for the decoder. Its shape is [N, T]. It is None if `supervision` is None. 
""" - if self.use_feat_batchnorm: - x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T] - x = self.feat_batchnorm(x) - x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C] + encoder_memory, memory_key_padding_mask = self.run_encoder( x, supervision ) From ea43b49ef2a9488150a08cd5769f0c2406d20269 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sun, 22 Aug 2021 11:56:22 +0800 Subject: [PATCH 04/26] Remove BatchNorm, use LayerNorm --- egs/librispeech/ASR/conformer_lm/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py index 3014055b42..47f36dcf34 100644 --- a/egs/librispeech/ASR/conformer_lm/conformer.py +++ b/egs/librispeech/ASR/conformer_lm/conformer.py @@ -864,7 +864,7 @@ def __init__( groups=channels, bias=bias, ) - self.norm = nn.BatchNorm1d(channels) + self.norm = nn.LayerNorm(channels) self.pointwise_conv2 = nn.Conv1d( channels, channels, From 03ff4aab2f5b71448e728adc4ea446192fe20dee Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 23 Aug 2021 11:11:09 +0800 Subject: [PATCH 05/26] Some progress on refactoring conformer code, it's in transformer.py only... --- egs/librispeech/ASR/conformer_lm/conformer.py | 1 - .../ASR/conformer_lm/transformer.py | 1440 +++++++++++------ 2 files changed, 982 insertions(+), 459 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py index 47f36dcf34..c727da3413 100644 --- a/egs/librispeech/ASR/conformer_lm/conformer.py +++ b/egs/librispeech/ASR/conformer_lm/conformer.py @@ -32,7 +32,6 @@ def __init__( self, num_features: int, num_classes: int, - subsampling_factor: int = 4, d_model: int = 256, nhead: int = 4, dim_feedforward: int = 2048, diff --git a/egs/librispeech/ASR/conformer_lm/transformer.py b/egs/librispeech/ASR/conformer_lm/transformer.py index 707eacd1b9..4367808a8b 100644 --- a/egs/librispeech/ASR/conformer_lm/transformer.py +++ b/egs/librispeech/ASR/conformer_lm/transformer.py @@ -12,10 +12,9 @@ Supervisions = Dict[str, torch.Tensor] -class Transformer(nn.Module): +class MaskedLmConformer(nn.Module): def __init__( self, - num_features: int, num_classes: int, d_model: int = 256, nhead: int = 4, @@ -23,17 +22,13 @@ def __init__( num_encoder_layers: int = 12, num_decoder_layers: int = 6, dropout: float = 0.1, - normalize_before: bool = True, + cnn_module_kernel: int = 31, ) -> None: """ Args: - num_features: - The input dimension of the model. num_classes: - The output dimension of the model. - subsampling_factor: - Number of output frames is num_in_frames // subsampling_factor. - Currently, subsampling_factor MUST be 4. + The input and output dimension of the model (inputs and outputs are + both discrete) d_model: Attention dimension. nhead: @@ -47,76 +42,45 @@ def __init__( Number of decoder layers. dropout: Dropout in encoder/decoder. - normalize_before: - If True, use pre-layer norm; False to use post-layer norm. - """ - super().__init__() - + """ + super(MaskedLmConformer, self).__init__() - self.num_features = num_features self.num_classes = num_classes - self.subsampling_factor = subsampling_factor - if subsampling_factor != 4: - raise NotImplementedError("Support only 'subsampling_factor=4'.") - - # self.encoder_embed converts the input of shape [N, T, num_classes] - # to the shape [N, T//subsampling_factor, d_model]. 
- # That is, it does two things simultaneously: - # (1) subsampling: T -> T//subsampling_factor - # (2) embedding: num_classes -> d_model - - #self.encoder_embed = [TODO...] - - - self.encoder_pos = PositionalEncoding(d_model, dropout) - - encoder_layer = TransformerEncoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - normalize_before=normalize_before, + # self.embed is the embedding used for both the encoder and decoder. + self.embed_scale = d_model ** 0.5 + self.embed = nn.Embedding( + num_embeddings=self.decoder_num_class, embedding_dim=d_model, + _weight=torch.randn(self.decoder_num_class, d_model) * (1 / self.embed_scale) ) - if normalize_before: - encoder_norm = nn.LayerNorm(d_model) - else: - encoder_norm = None + self.encoder_pos = RelPositionalEncoding(d_model, dropout) - self.encoder = nn.TransformerEncoder( - encoder_layer=encoder_layer, - num_layers=num_encoder_layers, - norm=encoder_norm, - ) - - # TODO(fangjun): remove dropout - self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) + encoder_layer = MaskedLmConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + cnn_module_kernel, ) + self.encoder = MaskedLmConformerEncoder(encoder_layer, num_encoder_layers, + norm=nn.LayerNorm(d_model)) if num_decoder_layers > 0: self.decoder_num_class = self.num_classes - self.decoder_embed = nn.Embedding( - num_embeddings=self.decoder_num_class, embedding_dim=d_model - ) - self.decoder_pos = PositionalEncoding(d_model, dropout) - - decoder_layer = TransformerDecoderLayer( + decoder_layer = TransformerDecoderLayerRelPos( d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, - normalize_before=normalize_before, ) - if normalize_before: - decoder_norm = nn.LayerNorm(d_model) - else: - decoder_norm = None + # Projects the embedding of `src`, to be added to `memory` + self.src_linear = torch.nn.Linear(d_model, d_model) - self.decoder = nn.TransformerDecoder( + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoderRelPos( decoder_layer=decoder_layer, num_layers=num_decoder_layers, norm=decoder_norm, @@ -126,344 +90,178 @@ def __init__( d_model, self.decoder_num_class ) - self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) - else: - self.decoder_criterion = None def forward( self, - src_symbols: torch.Tensor, - src_padding_mask: torch.Tensor = None - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + masked_src_symbols: torch.Tensor, + key_padding_mask: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: - src_symbols: + masked_src_symbols: The input symbols to be embedded (will actually have query positions masked), as a Tensor of shape (batch_size, seq_len) and dtype=torch.int64. I.e. shape (N, T) - src_padding_mask: + key_padding_mask: Either None, or a Tensor of shape (batch_size, seq_len) i.e. (N, T), and dtype=torch.bool which has True in positions to be masked in attention layers and convolutions because they represent padding at the ends of sequences. - supervision: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - (CAUTION: It contains length information, i.e., start and number of - frames, before subsampling) Returns: - Return a tuple containing 3 tensors: - - CTC output for ctc decoding. Its shape is [N, T, C] - - Encoder output with shape [T, N, C]. 
It can be used as key and - value for the decoder. - - Encoder output padding mask. It can be used as - memory_key_padding_mask for the decoder. Its shape is [N, T]. - It is None if `supervision` is None. + Returns (encoded, pos_emb), where: + `encoded` is a Tensor containing the encoded data; it is of shape (N, T, C) + where C is the embedding_dim. + `pos_emb` is a Tensor containing the relative positional encoding, of + shape (1, 2*T-1, C) """ - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, supervision - ) - x = self.ctc_output(encoder_memory) - return x, encoder_memory, memory_key_padding_mask - - def run_encoder( - self, x: torch.Tensor, supervisions: Optional[Supervisions] = None - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Run the transformer encoder. - - Args: - x: - The model input. Its shape is [N, T, C]. - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - CAUTION: It contains length information, i.e., start and number of - frames, before subsampling - It is read directly from the batch, without any sorting. It is used - to compute the encoder padding mask, which is used as memory key - padding mask for the decoder. - Returns: - Return a tuple with two tensors: - - The encoder output, with shape [T, N, C] - - encoder padding mask, with shape [N, T]. - The mask is None if `supervisions` is None. - It is used as memory key padding mask in the decoder. - """ - x = self.encoder_embed(x) - x = self.encoder_pos(x) + x = self.embed(masked_src_symbols) * self.embed_scale # (N, T, C) + x, pos_emb = self.encoder_pos(x) # pos_emb: (1, 2*T-1, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - mask = encoder_padding_mask(x.size(0), supervisions) - mask = mask.to(x.device) if mask is not None else None - x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) - return x, mask + x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C) - def ctc_output(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - The output tensor from the transformer encoder. - Its shape is [T, N, C] + return x, pos_emb - Returns: - Return a tensor that can be used for CTC decoding. - Its shape is [N, T, C] - """ - x = self.encoder_output_layer(x) - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) - return x - - def decoder_forward( + def decoder_nll( self, memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, - token_ids: List[List[int]], - sos_id: int, - eos_id: int, + pos_emb: torch.Tensor, + src_symbols: torch.Tensor, + tgt_symbols: torch.Tensor, + key_padding_mask: torch.Tensor ) -> torch.Tensor: """ Args: memory: - It's the output of the encoder with shape [T, N, C] - memory_key_padding_mask: - The padding mask from the encoder. - token_ids: - A list-of-list IDs. Each sublist contains IDs for an utterance. - The IDs can be either phone IDs or word piece IDs. - sos_id: - sos token id - eos_id: - eos token id + The output of the encoder, with shape (T, N, C) + pos_emb: + Relative positional embedding, of shape (1, 2*T-1, C), as + returned from the encoder + src_symbols: + The un-masked src symbols, a LongTensor of shape (N, T). + Can be used to predict the target + only in a left-to-right manner (otherwise it's cheating). + tgt_symbols: + Target symbols, a LongTensor of shape (N, T). 
+ The same as src_symbols, but shifted by one (and also, + without symbol randomization, see randomize_proportion + in dataloader) + key_padding_mask: + A BoolTensor of shape (N, T), with True for positions + that correspond to padding at the end of source and + memory sequences. The same mask is used for self-attention + and cross-attention, since the padding is the same. Returns: - A scalar, the **sum** of label smoothing loss over utterances - in the batch without any normalization. + Returns a tensor of shape (N, T), containing the negative + log-probabilities for the target symbols at each position + in the target sequence. """ - ys_in = add_sos(token_ids, sos_id=sos_id) - ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + (T, N, C) = memory.shape - ys_out = add_eos(token_ids, eos_id=eos_id) - ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) + tgt_mask = generate_square_subsequent_mask(T, memory.device) - device = memory.device - ys_in_pad = ys_in_pad.to(device) - ys_out_pad = ys_out_pad.to(device) + src = self.embed(src_symbols) * self.embed_scale # (N, T) -> (N, T, C) + src = src.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - # TODO: Use length information to create the decoder padding mask - # We set the first column to False since the first column in ys_in_pad - # contains sos_id, which is the same as eos_id in our current setting. - tgt_key_padding_mask[:, 0] = False - - tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - pred_pad = self.decoder( - tgt=tgt, + src = memory + self.src_linear(src) # (T, N, C) + + # This is a little confusing, how "tgt" is set to src. "src" is the + # symbol sequence without masking but with padding and randomization. + # "tgt" is like "src" but shifted by one. + pred = self.decoder( + tgt=src, memory=memory, tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, + tgt_key_padding_mask=key_padding_mask, + memory_key_padding_mask=key_padding_mask, ) # (T, N, C) - pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) - - decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) - - return decoder_loss - - def decoder_nll( - self, - memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, - token_ids: List[List[int]], - sos_id: int, - eos_id: int, - ) -> torch.Tensor: - """ - Args: - memory: - It's the output of the encoder with shape [T, N, C] - memory_key_padding_mask: - The padding mask from the encoder. - token_ids: - A list-of-list IDs (e.g., word piece IDs). - Each sublist represents an utterance. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. - Returns: - A 2-D tensor of shape (len(token_ids), max_token_length) - representing the cross entropy loss (i.e., negative log-likelihood). - """ - # The common part between this function and decoder_forward could be - # extracted as a separate function. 
- - ys_in = add_sos(token_ids, sos_id=sos_id) - ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) - - ys_out = add_eos(token_ids, eos_id=eos_id) - ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) - device = memory.device - ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) - ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) + pred = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred = self.decoder_output_layer(pred) # (N, T, C) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) - - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - # TODO: Use length information to create the decoder padding mask - # We set the first column to False since the first column in ys_in_pad - # contains sos_id, which is the same as eos_id in our current setting. - tgt_key_padding_mask[:, 0] = False - - tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - pred_pad = self.decoder( - tgt=tgt, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - ) # (T, B, F) - pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) - pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) # nll: negative log-likelihood nll = torch.nn.functional.cross_entropy( - pred_pad.view(-1, self.decoder_num_class), - ys_out_pad.view(-1), - ignore_index=-1, + pred.view(-1, self.decoder_num_class), + tgt_symbols.view(-1), reduction="none", ) + nll = nll.view(N, T) + return nll - nll = nll.view(pred_pad.shape[0], -1) - return nll -class TransformerEncoderLayer(nn.Module): - """ - Modified from torch.nn.TransformerEncoderLayer. - Add support of normalize_before, - i.e., use layer_norm before the first block. +class TransformerDecoderRelPos(Module): + r"""TransformerDecoderRelPos is a stack of N decoder layers. + This is modified from nn.TransformerDecoder to support relative positional + encoding. Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - activation: - the activation function of intermediate layer, relu or - gelu (default=relu). - normalize_before: - whether to use layer_norm before the first block. + decoder_layer: an instance of the TransformerDecoderLayerRelPos() class (required). + num_layers: the number of sub-decoder-layers in the decoder (required). + norm: the layer normalization component (optional). 
Examples:: - >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> out = encoder_layer(src) + >>> decoder_layer = nn.TransformerDecoderLayerRelPos(d_model=512, nhead=8) + >>> transformer_decoder = nn.TransformerDecoderRelPos(decoder_layer, num_layers=6) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> pos_enc = torch.rand() + >>> out = transformer_decoder(tgt, memory) """ + __constants__ = ['norm'] - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: str = "relu", - normalize_before: bool = True, - ) -> None: - super(TransformerEncoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) + def __init__(self, decoder_layer, num_layers, norm=None): + super(TransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm - self.activation = _get_activation_fn(activation) - - self.normalize_before = normalize_before + def forward(self, x: Tensor, + pos_emb: Tensor, + memory: Tensor, + attn_mask: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None) -> Tensor: + r"""Pass the inputs (and mask) through the decoder layer in turn. - def __setstate__(self, state): - if "activation" not in state: - state["activation"] = nn.functional.relu - super(TransformerEncoderLayer, self).__setstate__(state) - - def forward( - self, - src: torch.Tensor, - src_mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + Args: + x: the input embedding sequence to the decoder (required): shape = (T, N, C). + Will be an embedding of `src_symbols` in practice + pos_emb: + A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with c==num_channels, + representing the relative positional encoding. + memory: the sequence from the last layer of the encoder (required): + shape = (T, N, C) + attn_mask: the mask for the `x` sequence's attention to itself, + of shape (T, T); in practice, will ensure that no + position can attend to later positions. A torch.Tensor with dtype=torch.float + or dtype=torch.bool. + key_padding_mask: the key-padding mask for both the memory and x sequences, + a torch.Tensor with dtype=bool and shape (N, T): true for masked + positions after the ends of sequences. """ - Pass the input through the encoder layer. - Args: - src: the sequence to the encoder layer (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional) + for mod in self.layers: + x = mod(x, pos_emb, memory, x_mask=x_mask, + key_padding_mask=key_padding_mask) - Shape: - src: (S, N, E). - src_mask: (S, S). - src_key_padding_mask: (N, S). 
- S is the source sequence length, T is the target sequence length, - N is the batch size, E is the feature number - """ - residual = src - if self.normalize_before: - src = self.norm1(src) - src2 = self.self_attn( - src, - src, - src, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - src = residual + self.dropout1(src2) - if not self.normalize_before: - src = self.norm1(src) + if self.norm is not None: + output = self.norm(output) - residual = src - if self.normalize_before: - src = self.norm2(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = residual + self.dropout2(src2) - if not self.normalize_before: - src = self.norm2(src) - return src + return output -class TransformerDecoderLayer(nn.Module): +class TransformerDecoderLayerRelPos(nn.Module): """ Modified from torch.nn.TransformerDecoderLayer. - Add support of normalize_before, - i.e., use layer_norm before the first block. + Add it to use normalize_before (hardcoded to True), i.e. use layer_norm before the first block; + to use relative positional encoding; and for some changes/simplifications in interface + because both sequences are the same length and have the same mask. Args: d_model: @@ -479,10 +277,11 @@ class TransformerDecoderLayer(nn.Module): gelu (default=relu). Examples:: - >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> decoder_layer = nn.TransformerDecoderLayerRelPos(d_model=512, nhead=8) >>> memory = torch.rand(10, 32, 512) >>> tgt = torch.rand(20, 32, 512) - >>> out = decoder_layer(tgt, memory) + >>> pos_emb = torch.rand(1, 20*2+1, 512) + >>> out = decoder_layer(tgt, pos_emb, memory) """ def __init__( @@ -492,11 +291,10 @@ def __init__( dim_feedforward: int = 2048, dropout: float = 0.1, activation: str = "relu", - normalize_before: bool = True, ) -> None: super(TransformerDecoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.src_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) @@ -511,7 +309,6 @@ def __init__( self.activation = _get_activation_fn(activation) - self.normalize_before = normalize_before def __setstate__(self, state): if "activation" not in state: @@ -520,75 +317,57 @@ def __setstate__(self, state): def forward( self, - tgt: torch.Tensor, + x: torch.Tensor, + pos_emb: torch.Tensor, memory: torch.Tensor, - tgt_mask: Optional[torch.Tensor] = None, - memory_mask: Optional[torch.Tensor] = None, - tgt_key_padding_mask: Optional[torch.Tensor] = None, - memory_key_padding_mask: Optional[torch.Tensor] = None, + x_mask: Optional[torch.Tensor] = None, + key_padding_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Pass the inputs (and mask) through the decoder layer. Args: - tgt: - the sequence to the decoder layer (required). + x + The input embedding, to be added to by the forward function, of shape (T, N, C). + Attention within x will be left-to-right only (causal), thanks to x_mask. + pos_emb: + A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with c==num_channels, + containing the relative positional encoding. memory: - the sequence from the last layer of the encoder (required). - tgt_mask: - the mask for the tgt sequence (optional). 
- memory_mask: - the mask for the memory sequence (optional). - tgt_key_padding_mask: - the mask for the tgt keys per batch (optional). - memory_key_padding_mask: - the mask for the memory keys per batch (optional). + the sequence from the last layer of the encoder (required). Shape = (T, N, C) + x_mask: + the mask for the x, to enforce causal (left to right) attention (optional). + Shape == (T, T); may be bool or float. The first T pertains to the output, + the second T to the input. + key_padding_mask: + the key-padding mask to use for both the x and memory sequences. Shep == (N, T); + may be bool (True==masked) or float (to be added to attention scores). - Shape: - tgt: (T, N, E). - memory: (S, N, E). - tgt_mask: (T, T). - memory_mask: (T, S). - tgt_key_padding_mask: (N, T). - memory_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, - N is the batch size, E is the feature number + Returns: + Returns 'x plus something', a torch.Tensor with dtype the same as x (e.g. float), + and shape (T, N, C). """ - residual = tgt - if self.normalize_before: - tgt = self.norm1(tgt) - tgt2 = self.self_attn( - tgt, - tgt, - tgt, - attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask, + residual = x + x = self.norm1(x) + self_attn = self.self_attn(x, x, x, + key_padding_mask=key_padding_mask, + need_weights=False + attn_mask=x_mask, )[0] - tgt = residual + self.dropout1(tgt2) - if not self.normalize_before: - tgt = self.norm1(tgt) - - residual = tgt - if self.normalize_before: - tgt = self.norm2(tgt) - tgt2 = self.src_attn( - tgt, - memory, - memory, - attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask, + x = residual + self.dropout1(self_attn) + + residual = x + x = self.norm2(x) + src_attn = self.src_attn(x, memory, memory, + key_padding_mask=key_padding_mask, + need_weights=False, )[0] - tgt = residual + self.dropout2(tgt2) - if not self.normalize_before: - tgt = self.norm2(tgt) + x = residual + self.dropout2(src_attn) - residual = tgt - if self.normalize_before: - tgt = self.norm3(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) - tgt = residual + self.dropout3(tgt2) - if not self.normalize_before: - tgt = self.norm3(tgt) - return tgt + residual = x + x = self.norm3(x) + ff = self.linear2(self.dropout(self.activation(self.linear1(x)))) + x = residual + self.dropout3(ff) + return x def _get_activation_fn(activation: str): @@ -831,62 +610,6 @@ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom -def encoder_padding_mask( - max_len: int, supervisions: Optional[Supervisions] = None -) -> Optional[torch.Tensor]: - """Make mask tensor containing indexes of padded part. - - TODO:: - This function **assumes** that the model uses - a subsampling factor of 4. We should remove that - assumption later. - - Args: - max_len: - Maximum length of input features. - CAUTION: It is the length after subsampling. - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - (CAUTION: It contains length information, i.e., start and number of - frames, before subsampling) - - Returns: - Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices. 
- """ - if supervisions is None: - return None - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"], - supervisions["num_frames"], - ), - 1, - ).to(torch.int32) - - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] - for idx in range(supervision_segments.size(0)): - # Note: TorchScript doesn't allow to unpack tensors as tuples - sequence_idx = supervision_segments[idx, 0].item() - start_frame = supervision_segments[idx, 1].item() - num_frames = supervision_segments[idx, 2].item() - lengths[sequence_idx] = start_frame + num_frames - - lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] - bs = int(len(lengths)) - seq_range = torch.arange(0, max_len, dtype=torch.int64) - seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) - # Note: TorchScript doesn't implement Tensor.new() - seq_length_expand = torch.tensor( - lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype - ).unsqueeze(-1) - mask = seq_range_expand >= seq_length_expand - - return mask def decoder_padding_mask( @@ -911,7 +634,7 @@ def decoder_padding_mask( return ys_mask -def generate_square_subsequent_mask(sz: int) -> torch.Tensor: +def generate_square_subsequent_mask(sz: int, device: torch.Device) -> torch.Tensor: """Generate a square mask for the sequence. The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0). The mask can be used for masked self-attention. @@ -928,7 +651,7 @@ def generate_square_subsequent_mask(sz: int) -> torch.Tensor: Returns: A square mask of dimension (sz, sz) """ - mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = (torch.triu(torch.ones(sz, sz, device=torch.Device)) == 1).transpose(0, 1) mask = ( mask.float() .masked_fill(mask == 0, float("-inf")) @@ -975,3 +698,804 @@ def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: for utt in token_ids: ans.append(utt + [eos_id]) return ans + + + +class MaskedConvolutionModule(nn.Module): + """ + This is used in the MaskedLmConformerLayer. It is the same as the ConvolutionModule + of theConformer code, but with key_padding_mask supported to make the output independent + of the batching. + + Modified, ultimately, from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + """ + + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: + """Construct a MaskedConvolutionModule object.""" + super(MaskedConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + self.norm = nn.LayerNorm(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward(self, x: Tensor, key_padding_mask: Optional[Tensor]) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (T, N, C) == (#time, batch, channels). 
+ key_padding_mask: if supplied, a Tensor with dtype=torch.Bool and + shape (N, T), with True for positions that correspond to + padding (and should be zeroed in convolutions). + + Returns: + Tensor: Output tensor (T, N, C) + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # Logical-not key_padding_mask, unsqueeze to shape (N, 1, T) and convert + # to float. Then we can just multiply by it when we need to apply + # masking, i.e. prior to the convolution over time. + if key_padding_mask is not None: + x = x * torch.logical_not(key_padding_mask).unsqueeze(1).to(dtype=x.dtype) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) # (time, batch, channel) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + + +class MaskedLmConformerEncoderLayer(nn.Module): + """ + MaskedLmConformerEncoderLayer is made up of self-attn, feedforward and convolution + networks. It's a simplified version of the conformer code we were previously + using, with pre-normalization hard-coded, relative positional encoding, + LayerNorm instead of BatchNorm in the convolution layers, and the key_padding_mask + applied also in the convolution layers. + + See: "Conformer: Convolution-augmented Transformer for Speech Recognition", for + the basic conformer. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. 
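For illustration (not part of the patch): the key idea of MaskedConvolutionModule.forward() above, shown on a plain tensor. Frames flagged in key_padding_mask are zeroed before any convolution over time, so the output does not depend on how much padding a batch happens to contain.

    import torch

    N, C, T = 2, 4, 6
    x = torch.randn(N, C, T)                       # (batch, channels, time)
    key_padding_mask = torch.zeros(N, T, dtype=torch.bool)
    key_padding_mask[1, 4:] = True                 # last two frames of sequence 1 are padding

    x = x * torch.logical_not(key_padding_mask).unsqueeze(1).to(x.dtype)
    assert torch.all(x[1, :, 4:] == 0.0)           # padded frames contribute nothing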
+ + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + ) -> None: + super(ConformerEncoderLayer, self).__init__() + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) + + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.feed_forward_macaron = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.conv_module = MaskedConvolutionModule(d_model, cnn_module_kernel) + + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module + self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + self.norm_mha = nn.LayerNorm(d_model) # for the MHA module + + self.ff_scale = 0.5 + + self.norm_conv = nn.LayerNorm(d_model) # for the CNN module + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + attn_mask: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + x: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + attn_mask: the mask for the x sequence's attention to itself (optional); + of shape (T, T) + key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + x: (T, N, C) i.e. (seq_len, batch_size, num_channels) + pos_emb: (N, 2*T-1, C) + attn_mask: (T, T) or (N*num_heads, T, T), of dtype torch.bool or torch.float, where + the 1st S is interpreted as the target sequence (output) and the 2nd as the source + sequence (input). + key_padding_mask: (N, T), of dtype torch.bool + + T is the sequence length, N is the batch size, C is the number of channels. + Return: + Returns x with something added to it, of shape (T, N, C) + """ + + # macaron style feed forward module + residual = x + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(x) + ) + + # multi-headed self-attention module + residual = x + x = self.norm_mha(x) + self_attn = self.self_attn(x, x, x, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False + )[0] + x = residual + self.dropout(self_attn) + + # convolution module + residual = x + x = self.norm_conv(x) + + x = residual + self.dropout(self.conv_module(x, key_padding_mask=key_padding_mask)) + + # feed forward module + residual = x + x = self.norm_ff(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + + x = self.norm_final(x) + + return x + + +def _get_clones(module, N): + return ModuleList([copy.deepcopy(module) for i in range(N)]) + +class MaskedLmConformerEncoder(Module): + r"""MaskedLmConformerEncoder is a stack of N encoder layers, modified from + torch.nn.TransformerEncoder + + Args: + encoder_layer: an instance of the MaskedLmConformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). 
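For illustration (not part of the patch): the block structure that MaskedLmConformerEncoderLayer.forward() above implements, written as a plain function over stand-in sub-modules. The argument names loosely mirror the layer's attributes; dropout and the attention/convolution arguments (pos_emb, masks) are omitted for brevity.

    def conformer_block(x, ff1, self_attn, conv, ff2, norms, ff_scale=0.5):
        norm_ff_macaron, norm_mha, norm_conv, norm_ff, norm_final = norms
        x = x + ff_scale * ff1(norm_ff_macaron(x))   # macaron feed-forward (half-scaled)
        x = x + self_attn(norm_mha(x))               # multi-head self-attention (pre-norm)
        x = x + conv(norm_conv(x))                   # masked convolution module
        x = x + ff_scale * ff2(norm_ff(x))           # second feed-forward (half-scaled)
        return norm_final(x)                         # final LayerNorm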
+
+    Examples::
+        >>> encoder_layer = MaskedLmConformerEncoderLayer(d_model=512, nhead=8)
+        >>> conformer_encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=6)
+        >>> src = torch.rand(10, 32, 512)
+        >>> src, pos_emb = self.encoder_pos(src)
+        >>> out = conformer_encoder(src, pos_emb)
+    """
+    __constants__ = ['norm']
+
+    def __init__(self, encoder_layer: nn.Module, num_layers: int,
+                 norm: Optional[nn.Module] = None):
+        super(MaskedLmConformerEncoder, self).__init__()
+        self.layers = _get_clones(encoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+
+
+    def forward(
+        self,
+        x: Tensor,
+        pos_emb: Tensor,
+        attn_mask: Optional[Tensor] = None,
+        key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        r"""Pass the input through the encoder layers in turn.
+
+        Args:
+          x: input of shape (T, N, C), i.e. (seq_len, batch, channels)
+          pos_emb: positional embedding tensor of shape (N, 2*T-1, C)
+          attn_mask (optional, likely not used): mask for self-attention of
+              x to itself, of shape (T, T)
+          key_padding_mask (optional): mask of shape (N, T); dtype must be bool.
+        Returns:
+          Returns a tensor with the same shape as x, i.e. (T, N, C).
+        """
+        for mod in self.layers:
+            x = mod(
+                x,
+                pos_emb,
+                attn_mask=attn_mask,
+                key_padding_mask=key_padding_mask,
+            )
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        return x
+
+
+class RelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding module.
+
+    See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
+
+    Args:
+        d_model: Embedding dimension.
+        dropout_rate: Dropout rate.
+        max_len: Maximum input length.
+
+    """
+
+    def __init__(
+        self, d_model: int, dropout_rate: float, max_len: int = 5000
+    ) -> None:
+        """Construct a RelPositionalEncoding object."""
+        super(RelPositionalEncoding, self).__init__()
+        self.d_model = d_model
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x: Tensor) -> None:
+        """Reset the positional encodings."""
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x.size(1) * 2 - 1:
+                # Note: TorchScript doesn't implement operator== for torch.Device
+                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                    x.device
+                ):
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` means the position of the query vector and `j` means the
+        # position of the key vector. We use positive relative positions when
+        # keys are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x.size(1), self.d_model)
+        pe_negative = torch.zeros(x.size(1), self.d_model)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of the positive part and concatenate it with the
+        # negative part; this layout supports the shifting trick used in
+        # RelPositionMultiheadAttention.rel_shift() below.
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
+        """Add positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+            torch.Tensor: Encoded tensor (1, 2*time-1, `*`).
+
+        """
+        self.extend_pe(x)
+        pos_emb = self.pe[
+            :,
+            self.pe.size(1) // 2
+            - x.size(1)
+            + 1 : self.pe.size(1) // 2
+            + x.size(1),
+        ]
+        return self.dropout(x), self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttention(nn.Module):
+    r"""Multi-Head Attention layer with relative position encoding
+
+    See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+
+    Args:
+        embed_dim: total dimension of the model.
+        num_heads: parallel attention heads.
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + # linear transformation for positional encoding. + self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + nn.init.xavier_uniform_(self.in_proj.weight) + nn.init.constant_(self.in_proj.bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: if true, return (output, attn_output_weights); else, (output, None). + + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(T, N, C)` where T is the output sequence length, N is the batch size, C is + the embedding dimension (number of channels). + - key: :math:`(S, N, C)`, where S is the input sequence length. + - value: :math:`(S, N, C)` + - pos_emb: :math:`(N, 2*T-1, C)`. Note: this assumes T == S, which it will be, but + still we use different letters because S relates to the input position, T to the + output posision. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the input sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(T, S)` where T is the output sequence length, S is the input sequence length. 
+ 3D mask :math:`(N*num_heads, T, S)` where N is the batch size, where T is the output sequence length, + S is the input sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Return: + (output, attn_output_weights) if need_weights==True, else (output, None), where: + + - output: :math:`(T, N, C)` where T is the output sequence length, N is the batch size, + C is the embedding/channel dimension. + - attn_output_weights: :math:`(N, T, S)` where N is the batch size, + T is the output sequence length, S is the input sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.weight, + self.in_proj.bias, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + ) + + def rel_shift(self, x: Tensor) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + assert n == 2 * time1 - 1 + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time1), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. 
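For illustration (not part of the patch): what rel_shift() above computes, written as a naive loop so the indexing is explicit. For input scores of shape (batch, head, time1, 2*time1-1), the output at [b, h, i, j] is the input at relative-position index (time1 - 1) + (j - i); the as_strided call in rel_shift() achieves the same result without copying.

    import torch

    def rel_shift_naive(x: torch.Tensor) -> torch.Tensor:
        (B, H, T, n) = x.shape
        assert n == 2 * T - 1
        out = torch.empty(B, H, T, T, dtype=x.dtype)
        for i in range(T):
            for j in range(T):
                out[:, :, i, j] = x[:, :, i, (T - 1) + (j - i)]
        return out

    x = torch.randn(2, 3, 4, 7)
    print(rel_shift_naive(x).shape)   # torch.Size([2, 3, 4, 4])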
+ + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. 
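For illustration (not part of the patch): the tensor shapes the attention module above expects in the usual self-attention case where the query and key sequences have the same length (T == S). The commented-out call assumes RelPositionMultiheadAttention behaves as documented above.

    import torch

    T, N, C, H = 10, 4, 256, 4
    query = key = value = torch.randn(T, N, C)
    pos_emb = torch.randn(1, 2 * T - 1, C)                   # relative positions -(T-1)..(T-1)
    key_padding_mask = torch.zeros(N, T, dtype=torch.bool)   # True marks padding
    attn_mask = torch.zeros(T, T)                            # float mask, added to scores

    # attn = RelPositionMultiheadAttention(embed_dim=C, num_heads=H)
    # out, weights = attn(query, key, value, pos_emb,
    #                     key_padding_mask=key_padding_mask, attn_mask=attn_mask)
    # out.shape == (T, N, C); weights.shape == (N, T, T)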
+ """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + #if not self.is_espnet_structure: + # q = q * scaling + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
+ ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + q_with_bias_u = (q + self.pos_bias_u).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self.pos_bias_v).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul( + q_with_bias_v, p.transpose(-2, -1) + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd) + + #if not self.is_espnet_structure: + # attn_output_weights = ( + # matrix_ac + matrix_bd + # ) # (batch, head, time1, time2) + #else: + + attn_output_weights = ( + matrix_ac + matrix_bd + ) * scaling # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) + + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None From e0b04ba54f3308b7e5a0125247b2c2ff5b0fe83e Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 23 Aug 2021 15:38:37 +0800 Subject: [PATCH 06/26] Progress in testing --- egs/librispeech/ASR/conformer_lm/conformer.py | 1150 +++++++++---- .../ASR/conformer_lm/test_conformer.py | 62 + .../ASR/conformer_lm/test_transformer.py | 89 - .../ASR/conformer_lm/transformer.py | 1501 ----------------- 4 files changed, 922 
insertions(+), 1880 deletions(-) create mode 100644 egs/librispeech/ASR/conformer_lm/test_conformer.py delete mode 100644 egs/librispeech/ASR/conformer_lm/test_transformer.py delete mode 100644 egs/librispeech/ASR/conformer_lm/transformer.py diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py index c727da3413..9f0db2e81b 100644 --- a/egs/librispeech/ASR/conformer_lm/conformer.py +++ b/egs/librispeech/ASR/conformer_lm/conformer.py @@ -1,36 +1,21 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) # Apache 2.0 import math -import warnings -from typing import Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch -from torch import Tensor, nn -from transformer import Supervisions, Transformer, encoder_padding_mask +import torch.nn as nn +from torch import Tensor +from torch.nn.utils.rnn import pad_sequence +# Note: TorchScript requires Dict/List/etc. to be fully typed. +Supervisions = Dict[str, torch.Tensor] -class Conformer(Transformer): - """ - Args: - num_features (int): Number of input features - num_classes (int): Number of output classes - subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) - d_model (int): attention dimension - nhead (int): number of head - dim_feedforward (int): feedforward dimention - num_encoder_layers (int): number of encoder layers - num_decoder_layers (int): number of decoder layers - dropout (float): dropout rate - cnn_module_kernel (int): Kernel size of convolution module - normalize_before (bool): whether to use layer_norm before the first block. - """ +class MaskedLmConformer(nn.Module): def __init__( self, - num_features: int, num_classes: int, d_model: int = 256, nhead: int = 4, @@ -39,81 +24,766 @@ def __init__( num_decoder_layers: int = 6, dropout: float = 0.1, cnn_module_kernel: int = 31, - normalize_before: bool = True, - is_espnet_structure: bool = False, ) -> None: - super(Conformer, self).__init__( - num_features=num_features, - num_classes=num_classes, - subsampling_factor=subsampling_factor, - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - num_encoder_layers=num_encoder_layers, - num_decoder_layers=num_decoder_layers, - dropout=dropout, - normalize_before=normalize_before, + """ + Args: + num_classes: + The input and output dimension of the model (inputs and outputs are + both discrete) + d_model: + Attention dimension. + nhead: + Number of heads in multi-head attention. + Must satisfy d_model // nhead == 0. + dim_feedforward: + The output dimension of the feedforward layers in encoder/decoder. + num_encoder_layers: + Number of encoder layers. + num_decoder_layers: + Number of decoder layers. + dropout: + Dropout in encoder/decoder. + """ + super(MaskedLmConformer, self).__init__() + + self.num_classes = num_classes + + # self.embed is the embedding used for both the encoder and decoder. 
+ self.embed_scale = d_model ** 0.5 + self.embed = nn.Embedding( + num_embeddings=self.decoder_num_class, embedding_dim=d_model, + _weight=torch.randn(self.decoder_num_class, d_model) * (1 / self.embed_scale) ) self.encoder_pos = RelPositionalEncoding(d_model, dropout) - encoder_layer = ConformerEncoderLayer( + encoder_layer = MaskedLmConformerEncoderLayer( d_model, nhead, dim_feedforward, dropout, cnn_module_kernel, - normalize_before, - is_espnet_structure, ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) - self.normalize_before = normalize_before - self.is_espnet_structure = is_espnet_structure - if self.normalize_before and self.is_espnet_structure: - self.after_norm = nn.LayerNorm(d_model) - else: - # Note: TorchScript detects that self.after_norm could be used inside forward() - # and throws an error without this change. - self.after_norm = identity + self.encoder = MaskedLmConformerEncoder(encoder_layer, num_encoder_layers, + norm=nn.LayerNorm(d_model)) - def run_encoder( - self, x: Tensor, supervisions: Optional[Supervisions] = None - ) -> Tuple[Tensor, Optional[Tensor]]: + if num_decoder_layers > 0: + self.decoder_num_class = self.num_classes + + decoder_layer = TransformerDecoderLayerRelPos( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + ) + + # Projects the embedding of `src`, to be added to `memory` + self.src_linear = torch.nn.Linear(d_model, d_model) + + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoderRelPos( + decoder_layer=decoder_layer, + num_layers=num_decoder_layers, + norm=decoder_norm, + ) + + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) + + + def forward( + self, + masked_src_symbols: torch.Tensor, + key_padding_mask: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + masked_src_symbols: + The input symbols to be embedded (will actually have query positions + masked), as a Tensor of shape (batch_size, seq_len) and dtype=torch.int64. + I.e. shape (N, T) + key_padding_mask: + Either None, or a Tensor of shape (batch_size, seq_len) i.e. (N, T), + and dtype=torch.bool which has True in positions to be masked in attention + layers and convolutions because they represent padding at the ends of + sequences. + + + Returns: + Returns (encoded, pos_emb), where: + `encoded` is a Tensor containing the encoded data; it is of shape (N, T, C) + where C is the embedding_dim. + `pos_emb` is a Tensor containing the relative positional encoding, of + shape (1, 2*T-1, C) + """ + + x = self.embed(masked_src_symbols) * self.embed_scale # (N, T, C) + x, pos_emb = self.encoder_pos(x) # pos_emb: (1, 2*T-1, C) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C) + + return x, pos_emb + + def decoder_nll( + self, + memory: torch.Tensor, + pos_emb: torch.Tensor, + src_symbols: torch.Tensor, + tgt_symbols: torch.Tensor, + key_padding_mask: torch.Tensor + ) -> torch.Tensor: """ + Args: + memory: + The output of the encoder, with shape (T, N, C) + pos_emb: + Relative positional embedding, of shape (1, 2*T-1, C), as + returned from the encoder + src_symbols: + The un-masked src symbols, a LongTensor of shape (N, T). + Can be used to predict the target + only in a left-to-right manner (otherwise it's cheating). + tgt_symbols: + Target symbols, a LongTensor of shape (N, T). 
+ The same as src_symbols, but shifted by one (and also, + without symbol randomization, see randomize_proportion + in dataloader) + key_padding_mask: + A BoolTensor of shape (N, T), with True for positions + that correspond to padding at the end of source and + memory sequences. The same mask is used for self-attention + and cross-attention, since the padding is the same. + + Returns: + Returns a tensor of shape (N, T), containing the negative + log-probabilities for the target symbols at each position + in the target sequence. + """ + (T, N, C) = memory.shape + + tgt_mask = generate_square_subsequent_mask(T, memory.device) + + src = self.embed(src_symbols) * self.embed_scale # (N, T) -> (N, T, C) + src = src.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + + src = memory + self.src_linear(src) # (T, N, C) + + # This is a little confusing, how "tgt" is set to src. "src" is the + # symbol sequence without masking but with padding and randomization. + # "tgt" is like "src" but shifted by one. + pred = self.decoder( + tgt=src, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=key_padding_mask, + memory_key_padding_mask=key_padding_mask, + ) # (T, N, C) + + pred = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred = self.decoder_output_layer(pred) # (N, T, C) + + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + pred.view(-1, self.decoder_num_class), + tgt_symbols.view(-1), + reduction="none", + ) + nll = nll.view(N, T) + return nll + + + + +class TransformerDecoderRelPos(nn.Module): + r"""TransformerDecoderRelPos is a stack of N decoder layers. + This is modified from nn.TransformerDecoder to support relative positional + encoding. + + Args: + decoder_layer: an instance of the TransformerDecoderLayerRelPos() class (required). + num_layers: the number of sub-decoder-layers in the decoder (required). + norm: the layer normalization component (optional). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayerRelPos(d_model=512, nhead=8) + >>> transformer_decoder = nn.TransformerDecoderRelPos(decoder_layer, num_layers=6) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> pos_enc = torch.rand() + >>> out = transformer_decoder(tgt, memory) + """ + __constants__ = ['norm'] + + def __init__(self, decoder_layer, num_layers, norm=None): + super(TransformerDecoderRelPos, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, x: Tensor, + pos_emb: Tensor, + memory: Tensor, + attn_mask: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None) -> Tensor: + r"""Pass the inputs (and mask) through the decoder layer in turn. + + Args: + x: the input embedding sequence to the decoder (required): shape = (T, N, C). + Will be an embedding of `src_symbols` in practice + pos_emb: + A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with c==num_channels, + representing the relative positional encoding. + memory: the sequence from the last layer of the encoder (required): + shape = (T, N, C) + attn_mask: the mask for the `x` sequence's attention to itself, + of shape (T, T); in practice, will ensure that no + position can attend to later positions. A torch.Tensor with dtype=torch.float + or dtype=torch.bool. + key_padding_mask: the key-padding mask for both the memory and x sequences, + a torch.Tensor with dtype=bool and shape (N, T): true for masked + positions after the ends of sequences. 
+ """ + + for mod in self.layers: + x = mod(x, pos_emb, memory, x_mask=x_mask, + key_padding_mask=key_padding_mask) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoderLayerRelPos(nn.Module): + """ + Modified from torch.nn.TransformerDecoderLayer. + Add it to use normalize_before (hardcoded to True), i.e. use layer_norm before the first block; + to use relative positional encoding; and for some changes/simplifications in interface + because both sequences are the same length and have the same mask. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayerRelPos(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> pos_emb = torch.rand(1, 20*2+1, 512) + >>> out = decoder_layer(tgt, pos_emb, memory) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + ) -> None: + super(TransformerDecoderLayer, self).__init__() + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.src_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerDecoderLayer, self).__setstate__(state) + + def forward( + self, + x: torch.Tensor, + pos_emb: torch.Tensor, + memory: torch.Tensor, + x_mask: Optional[torch.Tensor] = None, + key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Pass the inputs (and mask) through the decoder layer. + + Args: + x + The input embedding, to be added to by the forward function, of shape (T, N, C). + Attention within x will be left-to-right only (causal), thanks to x_mask. + pos_emb: + A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with c==num_channels, + containing the relative positional encoding. + memory: + the sequence from the last layer of the encoder (required). Shape = (T, N, C) + x_mask: + the mask for the x, to enforce causal (left to right) attention (optional). + Shape == (T, T); may be bool or float. The first T pertains to the output, + the second T to the input. + key_padding_mask: + the key-padding mask to use for both the x and memory sequences. Shep == (N, T); + may be bool (True==masked) or float (to be added to attention scores). + + Returns: + Returns 'x plus something', a torch.Tensor with dtype the same as x (e.g. float), + and shape (T, N, C). 
+ """ + residual = x + x = self.norm1(x) + self_attn = self.self_attn(x, x, x, + key_padding_mask=key_padding_mask, + need_weights=False, + attn_mask=x_mask, + )[0] + x = residual + self.dropout1(self_attn) + + residual = x + x = self.norm2(x) + src_attn = self.src_attn(x, memory, memory, + key_padding_mask=key_padding_mask, + need_weights=False, + )[0] + x = residual + self.dropout2(src_attn) + + residual = x + x = self.norm3(x) + ff = self.linear2(self.dropout(self.activation(self.linear1(x)))) + x = residual + self.dropout3(ff) + return x + + +def _get_activation_fn(activation: str): + if activation == "relu": + return nn.functional.relu + elif activation == "gelu": + return nn.functional.gelu + + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) + + +class PositionalEncoding(nn.Module): + """This class implements the positional encoding + proposed in the following paper: + + - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf + + PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) + PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) + + Note:: + + 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) + = exp(-1* 2i / d_model * log(100000)) + = exp(2i * -(log(10000) / d_model)) + """ + + def __init__(self, d_model: int, dropout: float = 0.1) -> None: + """ + Args: + d_model: + Embedding dimension. + dropout: + Dropout probability to be applied to the output of this module. + """ + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = nn.Dropout(p=dropout) + self.pe = None + + def extend_pe(self, x: torch.Tensor) -> None: + """Extend the time t in the positional encoding if required. + + The shape of `self.pe` is [1, T1, d_model]. The shape of the input x + is [N, T, d_model]. If T > T1, then we change the shape of self.pe + to [N, T, d_model]. Otherwise, nothing is done. + + Args: + x: + It is a tensor of shape [N, T, C]. + Returns: + Return None. + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + # Now pe is of shape [1, T, d_model], where T is x.size(1) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Add positional encoding. + + Args: + x: + Its shape is [N, T, C] + + Returns: + Return a tensor of shape [N, T, C] + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1), :] + return self.dropout(x) + + +class Noam(object): + """ + Implements Noam optimizer. 
+ + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) + + +class LabelSmoothingLoss(nn.Module): + """ + Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w) + and p_{prob. computed by model}(w) is minimized. + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa + + Args: + size: the number of class + padding_idx: padding_idx: ignored class id + smoothing: smoothing rate (0.0 means the conventional CE) + normalize_length: normalize loss by sequence length if True + criterion: loss function to be smoothed + """ + + def __init__( + self, + size: int, + padding_idx: int = -1, + smoothing: float = 0.1, + normalize_length: bool = False, + criterion: nn.Module = nn.KLDivLoss(reduction="none"), + ) -> None: + """Construct an LabelSmoothingLoss object.""" + super(LabelSmoothingLoss, self).__init__() + self.criterion = criterion + self.padding_idx = padding_idx + assert 0.0 < smoothing <= 1.0 + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.size = size + self.true_dist = None + self.normalize_length = normalize_length + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute loss between x and target. + Args: x: - The model input. Its shape is [N, T, C]. - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - CAUTION: It contains length information, i.e., start and number of - frames, before subsampling - It is read directly from the batch, without any sorting. 
It is used - to compute encoder padding mask, which is used as memory key padding - mask for the decoder. + prediction of dimension + (batch_size, input_length, number_of_classes). + target: + target masked with self.padding_id of + dimension (batch_size, input_length). Returns: - Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). - Tensor: Mask tensor of dimension (batch_size, input_length) + A scalar tensor containing the loss without normalization. """ - x = self.encoder_embed(x) - x, pos_emb = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - mask = encoder_padding_mask(x.size(0), supervisions) - if mask is not None: - mask = mask.to(x.device) - x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) + assert x.size(2) == self.size + # batch_size = x.size(0) + x = x.view(-1, self.size) + target = target.view(-1) + with torch.no_grad(): + true_dist = x.clone() + true_dist.fill_(self.smoothing / (self.size - 1)) + ignore = target == self.padding_idx # (B,) + total = len(target) - ignore.sum().item() + target = target.masked_fill(ignore, 0) # avoid -1 index + true_dist.scatter_(1, target.unsqueeze(1), self.confidence) + kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) + # denom = total if self.normalize_length else batch_size + denom = total if self.normalize_length else 1 + return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom + + + +def generate_square_subsequent_mask(sz: int, device: torch.device = torch.device('cpu')) -> torch.Tensor: + """Generate a square mask for the sequence. The masked positions are + filled with float('-inf'). Unmasked positions are filled with float(0.0). + The mask can be used for masked self-attention. + + For instance, if sz is 3, it returns:: + + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0]]) - if self.normalize_before and self.is_espnet_structure: - x = self.after_norm(x) + Args: + sz: mask size + + Returns: + A square mask of dimension (sz, sz) + """ + mask = (torch.triu(torch.ones(sz, sz, device=device)) == 1).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask - return x, mask +def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: + """Prepend sos_id to each utterance. -class ConformerEncoderLayer(nn.Module): + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + sos_id: + The ID of the SOS token. + + Return: + Return a new list-of-list, where each sublist starts + with SOS ID. """ - ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. - See: "Conformer: Convolution-augmented Transformer for Speech Recognition" + ans = [] + for utt in token_ids: + ans.append([sos_id] + utt) + return ans + + +def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: + """Append eos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + eos_id: + The ID of the EOS token. + + Return: + Return a new list-of-list, where each sublist ends + with EOS ID. + """ + ans = [] + for utt in token_ids: + ans.append(utt + [eos_id]) + return ans + + + +class MaskedConvolutionModule(nn.Module): + """ + This is used in the MaskedLmConformerLayer. 
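A minimal sketch of the smoothed target distribution that LabelSmoothingLoss.forward() builds before the KL term, with a hypothetical vocabulary of size 5, smoothing 0.1 and one padded position:

import torch

size, smoothing, padding_idx = 5, 0.1, -1
target = torch.tensor([2, 0, -1])                   # last position is padding
true_dist = torch.full((target.numel(), size), smoothing / (size - 1))
ignore = target == padding_idx
safe_target = target.masked_fill(ignore, 0)         # avoid the -1 index in scatter_
true_dist.scatter_(1, safe_target.unsqueeze(1), 1.0 - smoothing)
# Rows for real tokens hold 0.9 on the target class and 0.025 elsewhere; the
# padded row is zeroed out later via masked_fill before the loss is summed.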
It is the same as the ConvolutionModule + of theConformer code, but with key_padding_mask supported to make the output independent + of the batching. + + Modified, ultimately, from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + """ + + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: + """Construct a MaskedConvolutionModule object.""" + super(MaskedConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + self.norm = nn.LayerNorm(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward(self, x: Tensor, key_padding_mask: Optional[Tensor]) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (T, N, C) == (#time, batch, channels). + key_padding_mask: if supplied, a Tensor with dtype=torch.Bool and + shape (N, T), with True for positions that correspond to + padding (and should be zeroed in convolutions). + + Returns: + Tensor: Output tensor (T, N, C) + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # Logical-not key_padding_mask, unsqueeze to shape (N, 1, T) and convert + # to float. Then we can just multiply by it when we need to apply + # masking, i.e. prior to the convolution over time. + if key_padding_mask is not None: + x = x * torch.logical_not(key_padding_mask).unsqueeze(1).to(dtype=x.dtype) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) # (time, batch, channel) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + + +class MaskedLmConformerEncoderLayer(nn.Module): + """ + MaskedLmConformerEncoderLayer is made up of self-attn, feedforward and convolution + networks. It's a simplified version of the conformer code we were previously + using, with pre-normalization hard-coded, relative positional encoding, + LayerNorm instead of BatchNorm in the convolution layers, and the key_padding_mask + applied also in the convolution layers so the computation is independent of + how the sequences are batched. + + See: "Conformer: Convolution-augmented Transformer for Speech Recognition", for + the basic conformer. Args: d_model: the number of expected features in the input (required). @@ -121,7 +791,6 @@ class ConformerEncoderLayer(nn.Module): dim_feedforward: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). cnn_module_kernel (int): Kernel size of convolution module. 
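The effect of multiplying in the mask before the depthwise convolution can be seen with a plain Conv1d; this self-contained sketch uses toy sizes and only illustrates why the output at real frames stops depending on whatever sits in the padding:

import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv1d(4, 4, kernel_size=3, padding=1, groups=4)   # depthwise, 'SAME' padding

T, pad = 6, 3
x = torch.randn(1, 4, T)                                     # (N, C, T)
x_padded = torch.cat([x, torch.randn(1, 4, pad)], dim=2)     # junk where padding lives
mask = torch.zeros(1, T + pad, dtype=torch.bool)
mask[:, T:] = True                                           # True marks padded frames

y_short = conv(x)
y_junk = conv(x_padded)[:, :, :T]                            # padding leaks into frame T-1
x_masked = x_padded * torch.logical_not(mask).unsqueeze(1).to(x.dtype)
y_masked = conv(x_masked)[:, :, :T]

assert not torch.allclose(y_short, y_junk)        # unmasked padding changes the last frames
assert torch.allclose(y_short, y_masked, atol=1e-6)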
- normalize_before: whether to use layer_norm before the first block. Examples:: >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) @@ -137,12 +806,10 @@ def __init__( dim_feedforward: int = 2048, dropout: float = 0.1, cnn_module_kernel: int = 31, - normalize_before: bool = True, - is_espnet_structure: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure + d_model, nhead, dropout=0.0 ) self.feed_forward = nn.Sequential( @@ -159,7 +826,7 @@ def __init__( nn.Linear(dim_feedforward, d_model), ) - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + self.conv_module = MaskedConvolutionModule(d_model, cnn_module_kernel) self.norm_ff_macaron = nn.LayerNorm( d_model @@ -176,140 +843,129 @@ def __init__( self.dropout = nn.Dropout(dropout) - self.normalize_before = normalize_before - def forward( self, - src: Tensor, + x: Tensor, pos_emb: Tensor, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """ Pass the input through the encoder layer. Args: - src: the sequence to the encoder layer (required). + x: the sequence to the encoder layer (required). pos_emb: Positional embedding tensor (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). + attn_mask: the mask for the x sequence's attention to itself (optional); + of shape (T, T) + key_padding_mask: the mask for the src keys per batch (optional). Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, N is the batch size, E is the feature number + x: (T, N, C) i.e. (seq_len, batch_size, num_channels) + pos_emb: (1, 2*T-1, C) + attn_mask: (T, T) or (N*num_heads, T, T), of dtype torch.bool or torch.float, where + the 1st S is interpreted as the target sequence (output) and the 2nd as the source + sequence (input). + key_padding_mask: (N, T), of dtype torch.bool + + T is the sequence length, N is the batch size, C is the number of channels. 
+ Return: + Returns x with something added to it, of shape (T, N, C) """ # macaron style feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) + residual = x + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(x) ) - if not self.normalize_before: - src = self.norm_ff_macaron(src) # multi-headed self-attention module - residual = src - if self.normalize_before: - src = self.norm_mha(src) - src_att = self.self_attn( - src, - src, - src, - pos_emb=pos_emb, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, + residual = x + x = self.norm_mha(x) + self_attn = self.self_attn(x, x, x, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False )[0] - src = residual + self.dropout(src_att) - if not self.normalize_before: - src = self.norm_mha(src) + x = residual + self.dropout(self_attn) # convolution module - residual = src - if self.normalize_before: - src = self.norm_conv(src) - src = residual + self.dropout(self.conv_module(src)) - if not self.normalize_before: - src = self.norm_conv(src) + residual = x + x = self.norm_conv(x) + + x = residual + self.dropout(self.conv_module(x, key_padding_mask=key_padding_mask)) # feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) - if not self.normalize_before: - src = self.norm_ff(src) + residual = x + x = self.norm_ff(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) - if self.normalize_before: - src = self.norm_final(src) + x = self.norm_final(x) - return src + return x -class ConformerEncoder(nn.TransformerEncoder): - r"""ConformerEncoder is a stack of N encoder layers +def _get_clones(module, N): + return ModuleList([copy.deepcopy(module) for i in range(N)]) + +class MaskedLmConformerEncoder(nn.Module): + r"""MaskedLmConformerEncoder is a stack of N encoder layers, modified from + torch.nn.TransformerEncoder. The only differences are some name + changes for parameters. Args: - encoder_layer: an instance of the ConformerEncoderLayer() class (required). + encoder_layer: an instance of the MaskedLmConformerEncoderLayer() class (required). num_layers: the number of sub-encoder-layers in the encoder (required). norm: the layer normalization component (optional). 
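For readers new to the pre-norm, macaron-style layout, the residual pattern that forward() repeats four times looks like this in isolation (toy modules, with ReLU standing in for Swish):

import torch
import torch.nn as nn

d_model, dim_feedforward, ff_scale = 16, 32, 0.5
norm_ff_macaron = nn.LayerNorm(d_model)
feed_forward_macaron = nn.Sequential(
    nn.Linear(d_model, dim_feedforward), nn.ReLU(), nn.Linear(dim_feedforward, d_model)
)
dropout = nn.Dropout(0.1)

x = torch.randn(10, 2, d_model)        # (T, N, C)
residual = x
x = norm_ff_macaron(x)                 # normalize before the block (pre-norm)
x = residual + ff_scale * dropout(feed_forward_macaron(x))
# The same "residual + (scaled, dropped-out) block(norm(x))" shape is used for
# self-attention, the masked convolution and the second feed-forward, and the
# layer ends with norm_final.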
Examples:: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) + >>> encoder_layer = MaskedLmConformerEncoderLayer(d_model=512, nhead=8) + >>> conformer_encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=6) >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) + >>> src, pos_emb = self.encoder_pos(src) >>> out = conformer_encoder(src, pos_emb) """ + __constants__ = ['norm'] + + def __init__(self, encoder_layer: nn.Module, num_layers: int, + norm: Optional[nn.Module] = None): + super(MaskedLmConformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm - def __init__( - self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None - ) -> None: - super(ConformerEncoder, self).__init__( - encoder_layer=encoder_layer, num_layers=num_layers, norm=norm - ) def forward( self, - src: Tensor, + x: Tensor, pos_emb: Tensor, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, ) -> Tensor: r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - pos_emb: Positional embedding tensor (required). - mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - + Args + x: input of shape (T, N, C), i.e. (seq_len, batch, channels) + pos_emb: positional embedding tensor of shape (1, 2*T-1, C), + attn_mask (optional, likely not used): mask for self-attention of + x to itself, of shape (T, T) + key_padding_mask (optional): mask of shape (N, T), dtype must be bool. + Returns: + Returns a tensor with the same shape as x, i.e. (T, N, C). """ - output = src - for mod in self.layers: - output = mod( - output, + x = mod( + x, pos_emb, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, ) if self.norm is not None: - output = self.norm(output) + x = self.norm(x) - return output + return x class RelPositionalEncoding(torch.nn.Module): @@ -331,7 +987,6 @@ def __init__( """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model - self.xscale = math.sqrt(self.d_model) self.dropout = torch.nn.Dropout(p=dropout_rate) self.pe = None self.extend_pe(torch.tensor(0.0).expand(1, max_len)) @@ -348,7 +1003,7 @@ def extend_pe(self, x: Tensor) -> None: ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return - # Suppose `i` means to the position of query vecotr and `j` means the + # Suppose `i` means to the position of query vector and `j` means the # position of key vector. We use position relative positions when keys # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: """Add positional encoding. Args: - x (torch.Tensor): Input tensor (batch, time, `*`). + x (torch.Tensor): Input tensor (batch, time, C). - Returns: - torch.Tensor: Encoded tensor (batch, time, `*`). - torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + Returns (x, pos_enc): + x: torch.Tensor: x itself, with dropout added: (batch, time, C). 
+ pos_enc: torch.Tensor: Relative positional encoding as tensor of shape (1, 2*time-1, C). """ self.extend_pe(x) - x = x * self.xscale pos_emb = self.pe[ :, self.pe.size(1) // 2 @@ -407,7 +1061,7 @@ class RelPositionMultiheadAttention(nn.Module): Examples:: >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + >>> attn_output, attn_output_weights = rel_pos_multihead_attn(query, key, value, pos_emb) """ def __init__( @@ -415,7 +1069,6 @@ def __init__( embed_dim: int, num_heads: int, dropout: float = 0.0, - is_espnet_structure: bool = False, ) -> None: super(RelPositionMultiheadAttention, self).__init__() self.embed_dim = embed_dim @@ -438,8 +1091,6 @@ def __init__( self._reset_parameters() - self.is_espnet_structure = is_espnet_structure - def _reset_parameters(self) -> None: nn.init.xavier_uniform_(self.in_proj.weight) nn.init.constant_(self.in_proj.bias, 0.0) @@ -459,7 +1110,7 @@ def forward( attn_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor]]: r""" - Args: + Args (see below for shapes): query, key, value: map a query and a set of key-value pairs to an output. pos_emb: Positional embedding tensor key_padding_mask: if provided, specified padding elements in the key will @@ -467,37 +1118,40 @@ def forward( the corresponding value on the attention layer will be ignored. When given a byte mask and a value is non-zero, the corresponding value on the attention layer will be ignored - need_weights: output attn_output_weights. + need_weights: if true, return (output, attn_output_weights); else, (output, None). + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch. Shape: - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + - query: :math:`(T, N, C)` where T is the output sequence length, N is the batch size, C is + the embedding dimension (number of channels). + - key: :math:`(S, N, C)`, where S is the input sequence length. + - value: :math:`(S, N, C)` + - pos_emb: :math:`(N, 2*T-1, C)` or :math:`(1, 2*T-1, C)`. Note: this assumes T == S, which it will be, but + still we use different letters because S relates to the input position, T to the + output posision. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the input sequence length. If a ByteTensor is provided, the non-zero positions will be ignored while the position with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. 
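The shape conventions above can be exercised directly, much as test_conformer.py does; a sketch assuming conformer.py from this patch is importable:

import torch
from conformer import RelPositionalEncoding, RelPositionMultiheadAttention

T, N, C, num_heads = 12, 3, 64, 4
pos_module = RelPositionalEncoding(C, dropout_rate=0.0)
attn = RelPositionMultiheadAttention(C, num_heads)

x = torch.randn(N, T, C)                      # RelPositionalEncoding wants (N, T, C)
x, pos_emb = pos_module(x)
assert pos_emb.shape == (1, 2 * T - 1, C)     # one slot per relative offset -(T-1) .. T-1

x = x.transpose(0, 1)                         # the attention module wants (T, N, C)
out, weights = attn(x, x, x, pos_emb, need_weights=True)
assert out.shape == (T, N, C)
assert weights.shape == (N, T, T)             # attention weights averaged over heads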
attn_mask ensure that position i is allowed to attend the unmasked + - attn_mask: 2D mask :math:`(T, S)` where T is the output sequence length, S is the input sequence length. + 3D mask :math:`(N*num_heads, T, S)` where N is the batch size, where T is the output sequence length, + S is the input sequence length. attn_mask ensure that position i is allowed to attend the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. + Return: + (output, attn_output_weights) if need_weights==True, else (output, None), where: + + - output: :math:`(T, N, C)` where T is the output sequence length, N is the batch size, + C is the embedding/channel dimension. + - attn_output_weights: :math:`(N, T, S)` where N is the batch size, + T is the output sequence length, S is the input sequence length (actually + S and T are the same number). """ return self.multi_head_attention_forward( query, @@ -669,8 +1323,8 @@ def multi_head_attention_forward( _b = _b[_start:] v = nn.functional.linear(value, _w, _b) - if not self.is_espnet_structure: - q = q * scaling + #if not self.is_espnet_structure: + # q = q * scaling if attn_mask is not None: assert ( @@ -764,14 +1418,15 @@ def multi_head_attention_forward( ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - if not self.is_espnet_structure: - attn_output_weights = ( - matrix_ac + matrix_bd - ) # (batch, head, time1, time2) - else: - attn_output_weights = ( - matrix_ac + matrix_bd - ) * scaling # (batch, head, time1, time2) + #if not self.is_espnet_structure: + # attn_output_weights = ( + # matrix_ac + matrix_bd + # ) # (batch, head, time1, time2) + #else: + + attn_output_weights = ( + matrix_ac + matrix_bd + ) * scaling # (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, -1 @@ -825,88 +1480,3 @@ def multi_head_attention_forward( return attn_output, attn_output_weights.sum(dim=1) / num_heads else: return attn_output, None - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Conformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). 
- - """ - - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: - """Construct an ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - self.pointwise_conv1 = nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.depthwise_conv = nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=(kernel_size - 1) // 2, - groups=channels, - bias=bias, - ) - self.norm = nn.LayerNorm(channels) - self.pointwise_conv2 = nn.Conv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.activation = Swish() - - def forward(self, x: Tensor) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - - Returns: - Tensor: Output tensor (#time, batch, channels). - - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # 1D Depthwise Conv - x = self.depthwise_conv(x) - x = self.activation(self.norm(x)) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1) - - -class Swish(torch.nn.Module): - """Construct an Swish object.""" - - def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x) - - -def identity(x): - return x diff --git a/egs/librispeech/ASR/conformer_lm/test_conformer.py b/egs/librispeech/ASR/conformer_lm/test_conformer.py new file mode 100644 index 0000000000..8aaae4277d --- /dev/null +++ b/egs/librispeech/ASR/conformer_lm/test_conformer.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# run with: +# python3 -m pytest test_conformer.py + +import torch +from conformer import ( + TransformerDecoderRelPos, + MaskedLmConformer, + RelPositionMultiheadAttention, + RelPositionalEncoding, + generate_square_subsequent_mask, +) + +from torch.nn.utils.rnn import pad_sequence + + +def test_rel_position_multihead_attention(): + # Also tests RelPositionalEncoding + embed_dim = 256 + num_heads = 4 + T = 25 + N = 4 + C = 256 + pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0) + rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + + x = torch.randn(N, T, C) + #pos_emb = torch.randn(1, 2*T-1, C) + x, pos_enc = pos_emb_module(x) + print("pos_enc.shape=", pos_enc.shape) + x = x.transpose(0, 1) # (T, N, C) + attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_enc) + + +def test_transformer(): + return + num_features = 40 + num_classes = 87 + model = Transformer(num_features=num_features, num_classes=num_classes) + + N = 31 + + for T in range(7, 30): + x = torch.rand(N, T, num_features) + y, _, _ = model(x) + assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes) + + +def test_generate_square_subsequent_mask(): + s = 5 + mask = generate_square_subsequent_mask(s, torch.device('cpu')) + inf = float("inf") + expected_mask = torch.tensor( + [ + [0.0, -inf, -inf, -inf, -inf], + [0.0, 0.0, -inf, -inf, -inf], + [0.0, 0.0, 0.0, -inf, -inf], + [0.0, 0.0, 0.0, 0.0, -inf], + [0.0, 0.0, 0.0, 0.0, 0.0], + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) diff --git a/egs/librispeech/ASR/conformer_lm/test_transformer.py b/egs/librispeech/ASR/conformer_lm/test_transformer.py deleted file mode 
100644 index 08e6806074..0000000000 --- a/egs/librispeech/ASR/conformer_lm/test_transformer.py +++ /dev/null @@ -1,89 +0,0 @@ -#!/usr/bin/env python3 - -import torch -from transformer import ( - Transformer, - encoder_padding_mask, - generate_square_subsequent_mask, - decoder_padding_mask, - add_sos, - add_eos, -) - -from torch.nn.utils.rnn import pad_sequence - - -def test_encoder_padding_mask(): - supervisions = { - "sequence_idx": torch.tensor([0, 1, 2]), - "start_frame": torch.tensor([0, 0, 0]), - "num_frames": torch.tensor([18, 7, 13]), - } - - max_len = ((18 - 1) // 2 - 1) // 2 - mask = encoder_padding_mask(max_len, supervisions) - expected_mask = torch.tensor( - [ - [False, False, False], # ((18 - 1)//2 - 1)//2 = 3, - [False, True, True], # ((7 - 1)//2 - 1)//2 = 1, - [False, False, True], # ((13 - 1)//2 - 1)//2 = 2, - ] - ) - assert torch.all(torch.eq(mask, expected_mask)) - - -def test_transformer(): - num_features = 40 - num_classes = 87 - model = Transformer(num_features=num_features, num_classes=num_classes) - - N = 31 - - for T in range(7, 30): - x = torch.rand(N, T, num_features) - y, _, _ = model(x) - assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes) - - -def test_generate_square_subsequent_mask(): - s = 5 - mask = generate_square_subsequent_mask(s) - inf = float("inf") - expected_mask = torch.tensor( - [ - [0.0, -inf, -inf, -inf, -inf], - [0.0, 0.0, -inf, -inf, -inf], - [0.0, 0.0, 0.0, -inf, -inf], - [0.0, 0.0, 0.0, 0.0, -inf], - [0.0, 0.0, 0.0, 0.0, 0.0], - ] - ) - assert torch.all(torch.eq(mask, expected_mask)) - - -def test_decoder_padding_mask(): - x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])] - y = pad_sequence(x, batch_first=True, padding_value=-1) - mask = decoder_padding_mask(y, ignore_id=-1) - expected_mask = torch.tensor( - [ - [False, False, True], - [False, True, True], - [False, False, False], - ] - ) - assert torch.all(torch.eq(mask, expected_mask)) - - -def test_add_sos(): - x = [[1, 2], [3], [2, 5, 8]] - y = add_sos(x, sos_id=0) - expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]] - assert y == expected_y - - -def test_add_eos(): - x = [[1, 2], [3], [2, 5, 8]] - y = add_eos(x, eos_id=0) - expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]] - assert y == expected_y diff --git a/egs/librispeech/ASR/conformer_lm/transformer.py b/egs/librispeech/ASR/conformer_lm/transformer.py deleted file mode 100644 index 4367808a8b..0000000000 --- a/egs/librispeech/ASR/conformer_lm/transformer.py +++ /dev/null @@ -1,1501 +0,0 @@ -# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# Apache 2.0 - -import math -from typing import Dict, List, Optional, Tuple - -import torch -import torch.nn as nn -from torch.nn.utils.rnn import pad_sequence - -# Note: TorchScript requires Dict/List/etc. to be fully typed. -Supervisions = Dict[str, torch.Tensor] - - -class MaskedLmConformer(nn.Module): - def __init__( - self, - num_classes: int, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - num_decoder_layers: int = 6, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - ) -> None: - """ - Args: - num_classes: - The input and output dimension of the model (inputs and outputs are - both discrete) - d_model: - Attention dimension. - nhead: - Number of heads in multi-head attention. - Must satisfy d_model // nhead == 0. - dim_feedforward: - The output dimension of the feedforward layers in encoder/decoder. - num_encoder_layers: - Number of encoder layers. 
- num_decoder_layers: - Number of decoder layers. - dropout: - Dropout in encoder/decoder. - """ - super(MaskedLmConformer, self).__init__() - - self.num_classes = num_classes - - # self.embed is the embedding used for both the encoder and decoder. - self.embed_scale = d_model ** 0.5 - self.embed = nn.Embedding( - num_embeddings=self.decoder_num_class, embedding_dim=d_model, - _weight=torch.randn(self.decoder_num_class, d_model) * (1 / self.embed_scale) - ) - - self.encoder_pos = RelPositionalEncoding(d_model, dropout) - - encoder_layer = MaskedLmConformerEncoderLayer( - d_model, - nhead, - dim_feedforward, - dropout, - cnn_module_kernel, - ) - self.encoder = MaskedLmConformerEncoder(encoder_layer, num_encoder_layers, - norm=nn.LayerNorm(d_model)) - - if num_decoder_layers > 0: - self.decoder_num_class = self.num_classes - - decoder_layer = TransformerDecoderLayerRelPos( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - ) - - # Projects the embedding of `src`, to be added to `memory` - self.src_linear = torch.nn.Linear(d_model, d_model) - - decoder_norm = nn.LayerNorm(d_model) - self.decoder = TransformerDecoderRelPos( - decoder_layer=decoder_layer, - num_layers=num_decoder_layers, - norm=decoder_norm, - ) - - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) - - - def forward( - self, - masked_src_symbols: torch.Tensor, - key_padding_mask: torch.Tensor = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - masked_src_symbols: - The input symbols to be embedded (will actually have query positions - masked), as a Tensor of shape (batch_size, seq_len) and dtype=torch.int64. - I.e. shape (N, T) - key_padding_mask: - Either None, or a Tensor of shape (batch_size, seq_len) i.e. (N, T), - and dtype=torch.bool which has True in positions to be masked in attention - layers and convolutions because they represent padding at the ends of - sequences. - - - Returns: - Returns (encoded, pos_emb), where: - `encoded` is a Tensor containing the encoded data; it is of shape (N, T, C) - where C is the embedding_dim. - `pos_emb` is a Tensor containing the relative positional encoding, of - shape (1, 2*T-1, C) - """ - - x = self.embed(masked_src_symbols) * self.embed_scale # (N, T, C) - x, pos_emb = self.encoder_pos(x) # pos_emb: (1, 2*T-1, C) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - x = self.encoder(x, pos_emb, key_padding_mask=key_padding_mask) # (T, N, C) - - return x, pos_emb - - def decoder_nll( - self, - memory: torch.Tensor, - pos_emb: torch.Tensor, - src_symbols: torch.Tensor, - tgt_symbols: torch.Tensor, - key_padding_mask: torch.Tensor - ) -> torch.Tensor: - """ - Args: - memory: - The output of the encoder, with shape (T, N, C) - pos_emb: - Relative positional embedding, of shape (1, 2*T-1, C), as - returned from the encoder - src_symbols: - The un-masked src symbols, a LongTensor of shape (N, T). - Can be used to predict the target - only in a left-to-right manner (otherwise it's cheating). - tgt_symbols: - Target symbols, a LongTensor of shape (N, T). - The same as src_symbols, but shifted by one (and also, - without symbol randomization, see randomize_proportion - in dataloader) - key_padding_mask: - A BoolTensor of shape (N, T), with True for positions - that correspond to padding at the end of source and - memory sequences. The same mask is used for self-attention - and cross-attention, since the padding is the same. 
- - Returns: - Returns a tensor of shape (N, T), containing the negative - log-probabilities for the target symbols at each position - in the target sequence. - """ - (T, N, C) = memory.shape - - tgt_mask = generate_square_subsequent_mask(T, memory.device) - - src = self.embed(src_symbols) * self.embed_scale # (N, T) -> (N, T, C) - src = src.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - - src = memory + self.src_linear(src) # (T, N, C) - - # This is a little confusing, how "tgt" is set to src. "src" is the - # symbol sequence without masking but with padding and randomization. - # "tgt" is like "src" but shifted by one. - pred = self.decoder( - tgt=src, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=key_padding_mask, - memory_key_padding_mask=key_padding_mask, - ) # (T, N, C) - - pred = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - pred = self.decoder_output_layer(pred) # (N, T, C) - - # nll: negative log-likelihood - nll = torch.nn.functional.cross_entropy( - pred.view(-1, self.decoder_num_class), - tgt_symbols.view(-1), - reduction="none", - ) - nll = nll.view(N, T) - return nll - - - - -class TransformerDecoderRelPos(Module): - r"""TransformerDecoderRelPos is a stack of N decoder layers. - This is modified from nn.TransformerDecoder to support relative positional - encoding. - - Args: - decoder_layer: an instance of the TransformerDecoderLayerRelPos() class (required). - num_layers: the number of sub-decoder-layers in the decoder (required). - norm: the layer normalization component (optional). - - Examples:: - >>> decoder_layer = nn.TransformerDecoderLayerRelPos(d_model=512, nhead=8) - >>> transformer_decoder = nn.TransformerDecoderRelPos(decoder_layer, num_layers=6) - >>> memory = torch.rand(10, 32, 512) - >>> tgt = torch.rand(20, 32, 512) - >>> pos_enc = torch.rand() - >>> out = transformer_decoder(tgt, memory) - """ - __constants__ = ['norm'] - - def __init__(self, decoder_layer, num_layers, norm=None): - super(TransformerDecoder, self).__init__() - self.layers = _get_clones(decoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - - def forward(self, x: Tensor, - pos_emb: Tensor, - memory: Tensor, - attn_mask: Optional[Tensor] = None, - key_padding_mask: Optional[Tensor] = None) -> Tensor: - r"""Pass the inputs (and mask) through the decoder layer in turn. - - Args: - x: the input embedding sequence to the decoder (required): shape = (T, N, C). - Will be an embedding of `src_symbols` in practice - pos_emb: - A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with c==num_channels, - representing the relative positional encoding. - memory: the sequence from the last layer of the encoder (required): - shape = (T, N, C) - attn_mask: the mask for the `x` sequence's attention to itself, - of shape (T, T); in practice, will ensure that no - position can attend to later positions. A torch.Tensor with dtype=torch.float - or dtype=torch.bool. - key_padding_mask: the key-padding mask for both the memory and x sequences, - a torch.Tensor with dtype=bool and shape (N, T): true for masked - positions after the ends of sequences. - """ - - for mod in self.layers: - x = mod(x, pos_emb, memory, x_mask=x_mask, - key_padding_mask=key_padding_mask) - - if self.norm is not None: - output = self.norm(output) - - return output - - -class TransformerDecoderLayerRelPos(nn.Module): - """ - Modified from torch.nn.TransformerDecoderLayer. - Add it to use normalize_before (hardcoded to True), i.e. 
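The tail of decoder_nll() reduces to an ordinary per-position cross-entropy; a self-contained sketch with hypothetical sizes:

import torch

N, T, num_classes = 2, 5, 10
pred = torch.randn(N, T, num_classes)                 # decoder output logits, (N, T, C)
tgt_symbols = torch.randint(0, num_classes, (N, T))   # targets, shifted by one w.r.t. src
nll = torch.nn.functional.cross_entropy(
    pred.reshape(-1, num_classes), tgt_symbols.reshape(-1), reduction="none"
).reshape(N, T)
assert nll.shape == (N, T)                            # one negative log-prob per position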
use layer_norm before the first block; - to use relative positional encoding; and for some changes/simplifications in interface - because both sequences are the same length and have the same mask. - - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - activation: - the activation function of intermediate layer, relu or - gelu (default=relu). - - Examples:: - >>> decoder_layer = nn.TransformerDecoderLayerRelPos(d_model=512, nhead=8) - >>> memory = torch.rand(10, 32, 512) - >>> tgt = torch.rand(20, 32, 512) - >>> pos_emb = torch.rand(1, 20*2+1, 512) - >>> out = decoder_layer(tgt, pos_emb, memory) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: str = "relu", - ) -> None: - super(TransformerDecoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) - self.src_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.norm3 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.dropout3 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - - def __setstate__(self, state): - if "activation" not in state: - state["activation"] = nn.functional.relu - super(TransformerDecoderLayer, self).__setstate__(state) - - def forward( - self, - x: torch.Tensor, - pos_emb: torch.Tensor, - memory: torch.Tensor, - x_mask: Optional[torch.Tensor] = None, - key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Pass the inputs (and mask) through the decoder layer. - - Args: - x - The input embedding, to be added to by the forward function, of shape (T, N, C). - Attention within x will be left-to-right only (causal), thanks to x_mask. - pos_emb: - A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with c==num_channels, - containing the relative positional encoding. - memory: - the sequence from the last layer of the encoder (required). Shape = (T, N, C) - x_mask: - the mask for the x, to enforce causal (left to right) attention (optional). - Shape == (T, T); may be bool or float. The first T pertains to the output, - the second T to the input. - key_padding_mask: - the key-padding mask to use for both the x and memory sequences. Shep == (N, T); - may be bool (True==masked) or float (to be added to attention scores). - - Returns: - Returns 'x plus something', a torch.Tensor with dtype the same as x (e.g. float), - and shape (T, N, C). 
- """ - residual = x - x = self.norm1(x) - self_attn = self.self_attn(x, x, x, - key_padding_mask=key_padding_mask, - need_weights=False - attn_mask=x_mask, - )[0] - x = residual + self.dropout1(self_attn) - - residual = x - x = self.norm2(x) - src_attn = self.src_attn(x, memory, memory, - key_padding_mask=key_padding_mask, - need_weights=False, - )[0] - x = residual + self.dropout2(src_attn) - - residual = x - x = self.norm3(x) - ff = self.linear2(self.dropout(self.activation(self.linear1(x)))) - x = residual + self.dropout3(ff) - return x - - -def _get_activation_fn(activation: str): - if activation == "relu": - return nn.functional.relu - elif activation == "gelu": - return nn.functional.gelu - - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) - - -class PositionalEncoding(nn.Module): - """This class implements the positional encoding - proposed in the following paper: - - - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf - - PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) - PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) - - Note:: - - 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) - = exp(-1* 2i / d_model * log(100000)) - = exp(2i * -(log(10000) / d_model)) - """ - - def __init__(self, d_model: int, dropout: float = 0.1) -> None: - """ - Args: - d_model: - Embedding dimension. - dropout: - Dropout probability to be applied to the output of this module. - """ - super().__init__() - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = nn.Dropout(p=dropout) - self.pe = None - - def extend_pe(self, x: torch.Tensor) -> None: - """Extend the time t in the positional encoding if required. - - The shape of `self.pe` is [1, T1, d_model]. The shape of the input x - is [N, T, d_model]. If T > T1, then we change the shape of self.pe - to [N, T, d_model]. Otherwise, nothing is done. - - Args: - x: - It is a tensor of shape [N, T, C]. - Returns: - Return None. - """ - if self.pe is not None: - if self.pe.size(1) >= x.size(1): - if self.pe.dtype != x.dtype or self.pe.device != x.device: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) - position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) - * -(math.log(10000.0) / self.d_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - # Now pe is of shape [1, T, d_model], where T is x.size(1) - self.pe = pe.to(device=x.device, dtype=x.dtype) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Add positional encoding. - - Args: - x: - Its shape is [N, T, C] - - Returns: - Return a tensor of shape [N, T, C] - """ - self.extend_pe(x) - x = x * self.xscale + self.pe[:, : x.size(1), :] - return self.dropout(x) - - -class Noam(object): - """ - Implements Noam optimizer. 
- - Proposed in - "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf - - Modified from - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa - - Args: - params: - iterable of parameters to optimize or dicts defining parameter groups - model_size: - attention dimension of the transformer model - factor: - learning rate factor - warm_step: - warmup steps - """ - - def __init__( - self, - params, - model_size: int = 256, - factor: float = 10.0, - warm_step: int = 25000, - weight_decay=0, - ) -> None: - """Construct an Noam object.""" - self.optimizer = torch.optim.Adam( - params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay - ) - self._step = 0 - self.warmup = warm_step - self.factor = factor - self.model_size = model_size - self._rate = 0 - - @property - def param_groups(self): - """Return param_groups.""" - return self.optimizer.param_groups - - def step(self): - """Update parameters and rate.""" - self._step += 1 - rate = self.rate() - for p in self.optimizer.param_groups: - p["lr"] = rate - self._rate = rate - self.optimizer.step() - - def rate(self, step=None): - """Implement `lrate` above.""" - if step is None: - step = self._step - return ( - self.factor - * self.model_size ** (-0.5) - * min(step ** (-0.5), step * self.warmup ** (-1.5)) - ) - - def zero_grad(self): - """Reset gradient.""" - self.optimizer.zero_grad() - - def state_dict(self): - """Return state_dict.""" - return { - "_step": self._step, - "warmup": self.warmup, - "factor": self.factor, - "model_size": self.model_size, - "_rate": self._rate, - "optimizer": self.optimizer.state_dict(), - } - - def load_state_dict(self, state_dict): - """Load state_dict.""" - for key, value in state_dict.items(): - if key == "optimizer": - self.optimizer.load_state_dict(state_dict["optimizer"]) - else: - setattr(self, key, value) - - -class LabelSmoothingLoss(nn.Module): - """ - Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w) - and p_{prob. computed by model}(w) is minimized. - Modified from - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa - - Args: - size: the number of class - padding_idx: padding_idx: ignored class id - smoothing: smoothing rate (0.0 means the conventional CE) - normalize_length: normalize loss by sequence length if True - criterion: loss function to be smoothed - """ - - def __init__( - self, - size: int, - padding_idx: int = -1, - smoothing: float = 0.1, - normalize_length: bool = False, - criterion: nn.Module = nn.KLDivLoss(reduction="none"), - ) -> None: - """Construct an LabelSmoothingLoss object.""" - super(LabelSmoothingLoss, self).__init__() - self.criterion = criterion - self.padding_idx = padding_idx - assert 0.0 < smoothing <= 1.0 - self.confidence = 1.0 - smoothing - self.smoothing = smoothing - self.size = size - self.true_dist = None - self.normalize_length = normalize_length - - def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Compute loss between x and target. - - Args: - x: - prediction of dimension - (batch_size, input_length, number_of_classes). - target: - target masked with self.padding_id of - dimension (batch_size, input_length). - - Returns: - A scalar tensor containing the loss without normalization. 
- """ - assert x.size(2) == self.size - # batch_size = x.size(0) - x = x.view(-1, self.size) - target = target.view(-1) - with torch.no_grad(): - true_dist = x.clone() - true_dist.fill_(self.smoothing / (self.size - 1)) - ignore = target == self.padding_idx # (B,) - total = len(target) - ignore.sum().item() - target = target.masked_fill(ignore, 0) # avoid -1 index - true_dist.scatter_(1, target.unsqueeze(1), self.confidence) - kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) - # denom = total if self.normalize_length else batch_size - denom = total if self.normalize_length else 1 - return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom - - - - -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: - """Generate a length mask for input. - - The masked position are filled with True, - Unmasked positions are filled with False. - - Args: - ys_pad: - padded tensor of dimension (batch_size, input_length). - ignore_id: - the ignored number (the padding number) in ys_pad - - Returns: - Tensor: - a bool tensor of the same shape as the input tensor. - """ - ys_mask = ys_pad == ignore_id - return ys_mask - - -def generate_square_subsequent_mask(sz: int, device: torch.Device) -> torch.Tensor: - """Generate a square mask for the sequence. The masked positions are - filled with float('-inf'). Unmasked positions are filled with float(0.0). - The mask can be used for masked self-attention. - - For instance, if sz is 3, it returns:: - - tensor([[0., -inf, -inf], - [0., 0., -inf], - [0., 0., 0]]) - - Args: - sz: mask size - - Returns: - A square mask of dimension (sz, sz) - """ - mask = (torch.triu(torch.ones(sz, sz, device=torch.Device)) == 1).transpose(0, 1) - mask = ( - mask.float() - .masked_fill(mask == 0, float("-inf")) - .masked_fill(mask == 1, float(0.0)) - ) - return mask - - -def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: - """Prepend sos_id to each utterance. - - Args: - token_ids: - A list-of-list of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - sos_id: - The ID of the SOS token. - - Return: - Return a new list-of-list, where each sublist starts - with SOS ID. - """ - ans = [] - for utt in token_ids: - ans.append([sos_id] + utt) - return ans - - -def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: - """Append eos_id to each utterance. - - Args: - token_ids: - A list-of-list of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - eos_id: - The ID of the EOS token. - - Return: - Return a new list-of-list, where each sublist ends - with EOS ID. - """ - ans = [] - for utt in token_ids: - ans.append(utt + [eos_id]) - return ans - - - -class MaskedConvolutionModule(nn.Module): - """ - This is used in the MaskedLmConformerLayer. It is the same as the ConvolutionModule - of theConformer code, but with key_padding_mask supported to make the output independent - of the batching. - - Modified, ultimately, from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). 
- """ - - def __init__( - self, channels: int, kernel_size: int, bias: bool = True - ) -> None: - """Construct a MaskedConvolutionModule object.""" - super(MaskedConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - self.pointwise_conv1 = nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.depthwise_conv = nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=(kernel_size - 1) // 2, - groups=channels, - bias=bias, - ) - self.norm = nn.LayerNorm(channels) - self.pointwise_conv2 = nn.Conv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.activation = Swish() - - def forward(self, x: Tensor, key_padding_mask: Optional[Tensor]) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (T, N, C) == (#time, batch, channels). - key_padding_mask: if supplied, a Tensor with dtype=torch.Bool and - shape (N, T), with True for positions that correspond to - padding (and should be zeroed in convolutions). - - Returns: - Tensor: Output tensor (T, N, C) - - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # Logical-not key_padding_mask, unsqueeze to shape (N, 1, T) and convert - # to float. Then we can just multiply by it when we need to apply - # masking, i.e. prior to the convolution over time. - if key_padding_mask is not None: - x = x * torch.logical_not(key_padding_mask).unsqueeze(1).to(dtype=x.dtype) - - # 1D Depthwise Conv - x = self.depthwise_conv(x) - x = self.activation(self.norm(x)) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1) # (time, batch, channel) - - -class Swish(torch.nn.Module): - """Construct an Swish object.""" - - def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x) - - - -class MaskedLmConformerEncoderLayer(nn.Module): - """ - MaskedLmConformerEncoderLayer is made up of self-attn, feedforward and convolution - networks. It's a simplified version of the conformer code we were previously - using, with pre-normalization hard-coded, relative positional encoding, - LayerNorm instead of BatchNorm in the convolution layers, and the key_padding_mask - applied also in the convolution layers. - - See: "Conformer: Convolution-augmented Transformer for Speech Recognition", for - the basic conformer. - - Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. 
- - Examples:: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - ) -> None: - super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) - - self.feed_forward = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), - nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), - ) - - self.feed_forward_macaron = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), - nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), - ) - - self.conv_module = MaskedConvolutionModule(d_model, cnn_module_kernel) - - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module - self.norm_ff = nn.LayerNorm(d_model) # for the FNN module - self.norm_mha = nn.LayerNorm(d_model) # for the MHA module - - self.ff_scale = 0.5 - - self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block - - self.dropout = nn.Dropout(dropout) - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - attn_mask: Optional[Tensor] = None, - key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - - Args: - x: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - attn_mask: the mask for the x sequence's attention to itself (optional); - of shape (T, T) - key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - x: (T, N, C) i.e. (seq_len, batch_size, num_channels) - pos_emb: (N, 2*T-1, C) - attn_mask: (T, T) or (N*num_heads, T, T), of dtype torch.bool or torch.float, where - the 1st S is interpreted as the target sequence (output) and the 2nd as the source - sequence (input). - key_padding_mask: (N, T), of dtype torch.bool - - T is the sequence length, N is the batch size, C is the number of channels. - Return: - Returns x with something added to it, of shape (T, N, C) - """ - - # macaron style feed forward module - residual = x - x = self.norm_ff_macaron(x) - x = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(x) - ) - - # multi-headed self-attention module - residual = x - x = self.norm_mha(x) - self_attn = self.self_attn(x, x, x, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=key_padding_mask, - need_weights=False - )[0] - x = residual + self.dropout(self_attn) - - # convolution module - residual = x - x = self.norm_conv(x) - - x = residual + self.dropout(self.conv_module(x, key_padding_mask=key_padding_mask)) - - # feed forward module - residual = x - x = self.norm_ff(x) - x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) - - x = self.norm_final(x) - - return x - - -def _get_clones(module, N): - return ModuleList([copy.deepcopy(module) for i in range(N)]) - -class MaskedLmConformerEncoder(Module): - r"""MaskedLmConformerEncoder is a stack of N encoder layers, modified from - torch.nn.TransformerEncoder - - Args: - encoder_layer: an instance of the MaskedLmConformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - norm: the layer normalization component (optional). 
- - Examples:: - >>> encoder_layer = MaskedLmConformerEncoderLayer(d_model=512, nhead=8) - >>> conformer_encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> src, pos_emb = self.encoder_pos(src) - >>> out = conformer_encoder(src, pos_emb) - """ - __constants__ = ['norm'] - - def __init__(self, encoder_layer: nn.Module, num_layers: int, - norm: Optional[nn.Module] = None): - super(MaskedLmConformerEncoder, self).__init__() - self.layers = _get_clones(encoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - attn_mask: Optional[Tensor] = None, - key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - Args - x: input of shape (T, N, C), i.e. (seq_len, batch, channels) - pos_emb: positional embedding tensor of shape (N, 2*T-1, C), - attn_mask (optional, likely not used): mask for self-attention of - x to itself, of shape (T, T) - key_padding_mask (optional): mask of shape (N, T), dtype must be bool. - Returns: - Returns a tensor with the same shape as x, i.e. (T, N, C). - """ - for mod in self.layers: - x = mod( - x - pos_emb, - attn_mask=attn_mask, - key_padding_mask=key_padding_mask, - ) - - if self.norm is not None: - x = self.norm(x) - - return x - - -class RelPositionalEncoding(torch.nn.Module): - """Relative positional encoding module. - - See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py - - Args: - d_model: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length. - - """ - - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: - """Construct an PositionalEncoding object.""" - super(RelPositionalEncoding, self).__init__() - self.d_model = d_model - self.dropout = torch.nn.Dropout(p=dropout_rate) - self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - - def extend_pe(self, x: Tensor) -> None: - """Reset the positional encodings.""" - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x.size(1) * 2 - 1: - # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - # Suppose `i` means to the position of query vector and `j` means the - # position of key vector. We use position relative positions when keys - # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: - """Add positional encoding. - - Args: - x (torch.Tensor): Input tensor (batch, time, `*`). - - Returns: - torch.Tensor: Encoded tensor (batch, time, `*`). - torch.Tensor: Encoded tensor (1, 2*time-1, `*`). - - """ - self.extend_pe(x) - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - - x.size(1) - + 1 : self.pe.size(1) // 2 - + x.size(1), - ] - return self.dropout(x), self.dropout(pos_emb) - - -class RelPositionMultiheadAttention(nn.Module): - r"""Multi-Head Attention layer with relative position encoding - - See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - - Args: - embed_dim: total dimension of the model. - num_heads: parallel attention heads. 
- dropout: a Dropout layer on attn_output_weights. Default: 0.0. - - Examples:: - - >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - ) -> None: - super(RelPositionMultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert ( - self.head_dim * num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" - - self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) - - # linear transformation for positional encoding. - self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - - self._reset_parameters() - - def _reset_parameters(self) -> None: - nn.init.xavier_uniform_(self.in_proj.weight) - nn.init.constant_(self.in_proj.bias, 0.0) - nn.init.constant_(self.out_proj.bias, 0.0) - - nn.init.xavier_uniform_(self.pos_bias_u) - nn.init.xavier_uniform_(self.pos_bias_v) - - def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask and a value is True, - the corresponding value on the attention layer will be ignored. When given - a byte mask and a value is non-zero, the corresponding value on the attention - layer will be ignored - need_weights: if true, return (output, attn_output_weights); else, (output, None). - - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - - Inputs: - - query: :math:`(T, N, C)` where T is the output sequence length, N is the batch size, C is - the embedding dimension (number of channels). - - key: :math:`(S, N, C)`, where S is the input sequence length. - - value: :math:`(S, N, C)` - - pos_emb: :math:`(N, 2*T-1, C)`. Note: this assumes T == S, which it will be, but - still we use different letters because S relates to the input position, T to the - output posision. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the input sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(T, S)` where T is the output sequence length, S is the input sequence length. 
- 3D mask :math:`(N*num_heads, T, S)` where N is the batch size, where T is the output sequence length, - S is the input sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Return: - (output, attn_output_weights) if need_weights==True, else (output, None), where: - - - output: :math:`(T, N, C)` where T is the output sequence length, N is the batch size, - C is the embedding/channel dimension. - - attn_output_weights: :math:`(N, T, S)` where N is the batch size, - T is the output sequence length, S is the input sequence length. - """ - return self.multi_head_attention_forward( - query, - key, - value, - pos_emb, - self.embed_dim, - self.num_heads, - self.in_proj.weight, - self.in_proj.bias, - self.dropout, - self.out_proj.weight, - self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - ) - - def rel_shift(self, x: Tensor) -> Tensor: - """Compute relative positional encoding. - - Args: - x: Input tensor (batch, head, time1, 2*time1-1). - time1 means the length of query vector. - - Returns: - Tensor: tensor of shape (batch, head, time1, time2) - (note: time2 has the same value as time1, but it is for - the key, while time1 is for the query). - """ - (batch_size, num_heads, time1, n) = x.shape - assert n == 2 * time1 - 1 - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time1), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) - - def multi_head_attention_forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - embed_dim_to_check: int, - num_heads: int, - in_proj_weight: Tensor, - in_proj_bias: Tensor, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - embed_dim_to_check: total dimension of the model. - num_heads: parallel attention heads. - in_proj_weight, in_proj_bias: input projection weight and bias. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. 
- - Shape: - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence - length, N is the batch size, E is the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. 
- """ - - tgt_len, bsz, embed_dim = query.size() - assert embed_dim == embed_dim_to_check - assert key.size(0) == value.size(0) and key.size(1) == value.size(1) - - head_dim = embed_dim // num_heads - assert ( - head_dim * num_heads == embed_dim - ), "embed_dim must be divisible by num_heads" - scaling = float(head_dim) ** -0.5 - - if torch.equal(query, key) and torch.equal(key, value): - # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) - - elif torch.equal(key, value): - # encoder-decoder attention - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) - - else: - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = embed_dim * 2 - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - k = nn.functional.linear(key, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim * 2 - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - v = nn.functional.linear(value, _w, _b) - - #if not self.is_espnet_structure: - # q = q * scaling - - if attn_mask is not None: - assert ( - attn_mask.dtype == torch.float32 - or attn_mask.dtype == torch.float64 - or attn_mask.dtype == torch.float16 - or attn_mask.dtype == torch.uint8 - or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( - attn_mask.dtype - ) - if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." - ) - attn_mask = attn_mask.to(torch.bool) - - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [ - bsz * num_heads, - query.size(0), - key.size(0), - ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) - else: - raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) - ) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): - warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
- ) - key_padding_mask = key_padding_mask.to(torch.bool) - - q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) - k = k.contiguous().view(-1, bsz, num_heads, head_dim) - v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) - - src_len = k.size(0) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz, "{} == {}".format( - key_padding_mask.size(0), bsz - ) - assert key_padding_mask.size(1) == src_len, "{} == {}".format( - key_padding_mask.size(1), src_len - ) - - q = q.transpose(0, 1) # (batch, time1, head, d_k) - - pos_emb_bsz = pos_emb.size(0) - assert pos_emb_bsz in (1, bsz) # actually it is 1 - p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) - p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) - - q_with_bias_u = (q + self.pos_bias_u).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - q_with_bias_v = (q + self.pos_bias_v).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - # compute attention score - # first compute matrix a and matrix c - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) - - # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p.transpose(-2, -1) - ) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift(matrix_bd) - - #if not self.is_espnet_structure: - # attn_output_weights = ( - # matrix_ac + matrix_bd - # ) # (batch, head, time1, time2) - #else: - - attn_output_weights = ( - matrix_ac + matrix_bd - ) * scaling # (batch, head, time1, time2) - - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) - - assert list(attn_output_weights.size()) == [ - bsz * num_heads, - tgt_len, - src_len, - ] - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float("-inf")) - else: - attn_output_weights += attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), - ) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, src_len - ) - - attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=dropout_p, training=training - ) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias - ) - - if need_weights: - # average attention weights over heads - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - return attn_output, attn_output_weights.sum(dim=1) / num_heads - else: - return attn_output, None From 2fbe3b78fd831d82600d0d832eeae7f8556b5ea3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 23 Aug 2021 17:18:00 +0800 Subject: [PATCH 07/26] Add more testing; fix issue about channel dim of LayerNorm. 
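A note on the LayerNorm fix below (an illustrative sketch, not part of the diff):
torch.nn.LayerNorm normalizes over the trailing dimension, so inside the
convolution module, where the tensor is laid out as (batch, channel, time),
it has to be transposed to (batch, time, channel) before the norm and
transposed back afterwards, roughly:

    import torch
    import torch.nn as nn

    N, C, T = 4, 256, 25       # batch, channels (d_model), time
    norm = nn.LayerNorm(C)     # normalizes over the last dimension, i.e. C
    x = torch.randn(N, C, T)   # layout used inside the convolution module
    y = norm(x.transpose(1, 2)).transpose(1, 2)   # normalize over channels
    assert y.shape == (N, C, T)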
--- egs/librispeech/ASR/conformer_lm/conformer.py | 8 +++++--- .../ASR/conformer_lm/test_conformer.py | 20 ++++++++++++++++++- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py index 9f0db2e81b..35d7001195 100644 --- a/egs/librispeech/ASR/conformer_lm/conformer.py +++ b/egs/librispeech/ASR/conformer_lm/conformer.py @@ -757,8 +757,10 @@ def forward(self, x: Tensor, key_padding_mask: Optional[Tensor]) -> Tensor: # 1D Depthwise Conv x = self.depthwise_conv(x) - x = self.activation(self.norm(x)) - + x = x.transpose(1, 2) # (batch, time, channel) + x = self.norm(x) # LayerNorm requires channel be last dim. + x = x.transpose(1, 2) # (batch, channel, time) + x = self.activation(x) x = self.pointwise_conv2(x) # (batch, channel, time) return x.permute(2, 0, 1) # (time, batch, channel) @@ -807,7 +809,7 @@ def __init__( dropout: float = 0.1, cnn_module_kernel: int = 31, ) -> None: - super(ConformerEncoderLayer, self).__init__() + super(MaskedLmConformerEncoderLayer, self).__init__() self.self_attn = RelPositionMultiheadAttention( d_model, nhead, dropout=0.0 ) diff --git a/egs/librispeech/ASR/conformer_lm/test_conformer.py b/egs/librispeech/ASR/conformer_lm/test_conformer.py index 8aaae4277d..3fdd8a2221 100644 --- a/egs/librispeech/ASR/conformer_lm/test_conformer.py +++ b/egs/librispeech/ASR/conformer_lm/test_conformer.py @@ -6,6 +6,7 @@ from conformer import ( TransformerDecoderRelPos, MaskedLmConformer, + MaskedLmConformerEncoderLayer, RelPositionMultiheadAttention, RelPositionalEncoding, generate_square_subsequent_mask, @@ -27,11 +28,28 @@ def test_rel_position_multihead_attention(): x = torch.randn(N, T, C) #pos_emb = torch.randn(1, 2*T-1, C) x, pos_enc = pos_emb_module(x) - print("pos_enc.shape=", pos_enc.shape) x = x.transpose(0, 1) # (T, N, C) attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_enc) +def test_masked_lm_conformer_encoder_layer(): + # Also tests RelPositionalEncoding + embed_dim = 256 + num_heads = 4 + T = 25 + N = 4 + C = 256 + pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0) + encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads) + + + x = torch.randn(N, T, C) + x, pos_enc = pos_emb_module(x) + x = x.transpose(0, 1) # (T, N, C) + key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T) + y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask) + + def test_transformer(): return num_features = 40 From 556fae586fdaf078253e7213e2fd68acf9b74ae4 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 23 Aug 2021 17:22:03 +0800 Subject: [PATCH 08/26] Add testing for MaskedLmConformerEncoder --- egs/librispeech/ASR/conformer_lm/conformer.py | 3 ++- .../ASR/conformer_lm/test_conformer.py | 23 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py index 35d7001195..163f475434 100644 --- a/egs/librispeech/ASR/conformer_lm/conformer.py +++ b/egs/librispeech/ASR/conformer_lm/conformer.py @@ -1,6 +1,7 @@ # Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) # Apache 2.0 +import copy import math from typing import Dict, List, Optional, Tuple @@ -910,7 +911,7 @@ def forward( def _get_clones(module, N): - return ModuleList([copy.deepcopy(module) for i in range(N)]) + return torch.nn.ModuleList([copy.deepcopy(module) for i in range(N)]) class MaskedLmConformerEncoder(nn.Module): 
r"""MaskedLmConformerEncoder is a stack of N encoder layers, modified from diff --git a/egs/librispeech/ASR/conformer_lm/test_conformer.py b/egs/librispeech/ASR/conformer_lm/test_conformer.py index 3fdd8a2221..8c2b2efa42 100644 --- a/egs/librispeech/ASR/conformer_lm/test_conformer.py +++ b/egs/librispeech/ASR/conformer_lm/test_conformer.py @@ -6,6 +6,7 @@ from conformer import ( TransformerDecoderRelPos, MaskedLmConformer, + MaskedLmConformerEncoder, MaskedLmConformerEncoderLayer, RelPositionMultiheadAttention, RelPositionalEncoding, @@ -50,6 +51,28 @@ def test_masked_lm_conformer_encoder_layer(): y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask) +def test_masked_lm_conformer_encoder(): + # Also tests RelPositionalEncoding + embed_dim = 256 + num_heads = 4 + T = 25 + N = 4 + C = 256 + pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0) + encoder_layer = MaskedLmConformerEncoderLayer(embed_dim, num_heads) + norm = torch.nn.LayerNorm(embed_dim) + encoder = MaskedLmConformerEncoder(encoder_layer, num_layers=4, + norm=norm) + + + x = torch.randn(N, T, C) + x, pos_enc = pos_emb_module(x) + x = x.transpose(0, 1) # (T, N, C) + key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T) + y = encoder(x, pos_enc, key_padding_mask=key_padding_mask) + + + def test_transformer(): return num_features = 40 From 7856ab89fc011bb26e05e2a0371545137eaa6266 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 23 Aug 2021 17:39:37 +0800 Subject: [PATCH 09/26] Test, and fix, TransformerDecoderLayerRelPos --- egs/librispeech/ASR/conformer_lm/conformer.py | 16 +++++---- .../ASR/conformer_lm/test_conformer.py | 34 +++++++++++++++---- 2 files changed, 37 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py index 163f475434..e158e88d5b 100644 --- a/egs/librispeech/ASR/conformer_lm/conformer.py +++ b/egs/librispeech/ASR/conformer_lm/conformer.py @@ -249,7 +249,7 @@ def forward(self, x: Tensor, """ for mod in self.layers: - x = mod(x, pos_emb, memory, x_mask=x_mask, + x = mod(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask) if self.norm is not None: @@ -294,7 +294,7 @@ def __init__( dropout: float = 0.1, activation: str = "relu", ) -> None: - super(TransformerDecoderLayer, self).__init__() + super(TransformerDecoderLayerRelPos, self).__init__() self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.src_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) # Implementation of Feedforward model @@ -315,14 +315,14 @@ def __init__( def __setstate__(self, state): if "activation" not in state: state["activation"] = nn.functional.relu - super(TransformerDecoderLayer, self).__setstate__(state) + super(TransformerDecoderLayerRelPos, self).__setstate__(state) def forward( self, x: torch.Tensor, pos_emb: torch.Tensor, memory: torch.Tensor, - x_mask: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Pass the inputs (and mask) through the decoder layer. @@ -330,13 +330,13 @@ def forward( Args: x The input embedding, to be added to by the forward function, of shape (T, N, C). - Attention within x will be left-to-right only (causal), thanks to x_mask. + Attention within x will be left-to-right only (causal), thanks to attn_mask. pos_emb: A torch.Tensor with dtype=torch.float and shape (1, 2*T-1, C) with c==num_channels, containing the relative positional encoding. 
memory: the sequence from the last layer of the encoder (required). Shape = (T, N, C) - x_mask: + attn_mask: the mask for the x, to enforce causal (left to right) attention (optional). Shape == (T, T); may be bool or float. The first T pertains to the output, the second T to the input. @@ -351,15 +351,17 @@ def forward( residual = x x = self.norm1(x) self_attn = self.self_attn(x, x, x, + pos_emb=pos_emb, key_padding_mask=key_padding_mask, need_weights=False, - attn_mask=x_mask, + attn_mask=attn_mask, )[0] x = residual + self.dropout1(self_attn) residual = x x = self.norm2(x) src_attn = self.src_attn(x, memory, memory, + pos_emb=pos_emb, key_padding_mask=key_padding_mask, need_weights=False, )[0] diff --git a/egs/librispeech/ASR/conformer_lm/test_conformer.py b/egs/librispeech/ASR/conformer_lm/test_conformer.py index 8c2b2efa42..106b847380 100644 --- a/egs/librispeech/ASR/conformer_lm/test_conformer.py +++ b/egs/librispeech/ASR/conformer_lm/test_conformer.py @@ -5,6 +5,7 @@ import torch from conformer import ( TransformerDecoderRelPos, + TransformerDecoderLayerRelPos, MaskedLmConformer, MaskedLmConformerEncoder, MaskedLmConformerEncoderLayer, @@ -28,9 +29,9 @@ def test_rel_position_multihead_attention(): x = torch.randn(N, T, C) #pos_emb = torch.randn(1, 2*T-1, C) - x, pos_enc = pos_emb_module(x) + x, pos_emb = pos_emb_module(x) x = x.transpose(0, 1) # (T, N, C) - attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_enc) + attn_output, attn_output_weights = rel_pos_multihead_attn(x, x, x, pos_emb) def test_masked_lm_conformer_encoder_layer(): @@ -45,10 +46,10 @@ def test_masked_lm_conformer_encoder_layer(): x = torch.randn(N, T, C) - x, pos_enc = pos_emb_module(x) + x, pos_emb = pos_emb_module(x) x = x.transpose(0, 1) # (T, N, C) key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T) - y = encoder_layer(x, pos_enc, key_padding_mask=key_padding_mask) + y = encoder_layer(x, pos_emb, key_padding_mask=key_padding_mask) def test_masked_lm_conformer_encoder(): @@ -66,10 +67,31 @@ def test_masked_lm_conformer_encoder(): x = torch.randn(N, T, C) - x, pos_enc = pos_emb_module(x) + x, pos_emb = pos_emb_module(x) x = x.transpose(0, 1) # (T, N, C) key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T) - y = encoder(x, pos_enc, key_padding_mask=key_padding_mask) + y = encoder(x, pos_emb, key_padding_mask=key_padding_mask) + + +def test_transformer_decoder_layer_rel_pos(): + # Also tests RelPositionalEncoding + embed_dim = 256 + num_heads = 4 + T = 25 + N = 4 + C = 256 + pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0) + decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads) + + + x = torch.randn(N, T, C) + x, pos_emb = pos_emb_module(x) + x = x.transpose(0, 1) # (T, N, C) + key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T) + attn_mask = generate_square_subsequent_mask(T) + memory = torch.randn(T, N, C) + y = decoder_layer(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask) + From 5fecd246643b354d20362a87f9b05a25a5777ba3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 23 Aug 2021 17:48:00 +0800 Subject: [PATCH 10/26] Test, and fix, TransformerDecoderRelPos --- egs/librispeech/ASR/conformer_lm/conformer.py | 17 +++++++-------- .../ASR/conformer_lm/test_conformer.py | 21 ++++++++++++++++++- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py index e158e88d5b..6207dab84d 100644 --- a/egs/librispeech/ASR/conformer_lm/conformer.py 
+++ b/egs/librispeech/ASR/conformer_lm/conformer.py @@ -166,24 +166,24 @@ def decoder_nll( tgt_mask = generate_square_subsequent_mask(T, memory.device) - src = self.embed(src_symbols) * self.embed_scale # (N, T) -> (N, T, C) - src = src.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - + x = self.embed(src_symbols) * self.embed_scale # (N, T) -> (N, T, C) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - src = memory + self.src_linear(src) # (T, N, C) + x = memory + self.src_linear(x) # (T, N, C) # This is a little confusing, how "tgt" is set to src. "src" is the # symbol sequence without masking but with padding and randomization. # "tgt" is like "src" but shifted by one. pred = self.decoder( - tgt=src, + x, + pos_emb, memory=memory, tgt_mask=tgt_mask, tgt_key_padding_mask=key_padding_mask, memory_key_padding_mask=key_padding_mask, ) # (T, N, C) - pred = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred = pred.permute(1, 0, 2) # (T, N, C) -> (N, T, C) pred = self.decoder_output_layer(pred) # (N, T, C) # nll: negative log-likelihood @@ -247,15 +247,14 @@ def forward(self, x: Tensor, a torch.Tensor with dtype=bool and shape (N, T): true for masked positions after the ends of sequences. """ - for mod in self.layers: x = mod(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask) if self.norm is not None: - output = self.norm(output) + x = self.norm(x) - return output + return x class TransformerDecoderLayerRelPos(nn.Module): diff --git a/egs/librispeech/ASR/conformer_lm/test_conformer.py b/egs/librispeech/ASR/conformer_lm/test_conformer.py index 106b847380..99acfdcd08 100644 --- a/egs/librispeech/ASR/conformer_lm/test_conformer.py +++ b/egs/librispeech/ASR/conformer_lm/test_conformer.py @@ -74,7 +74,6 @@ def test_masked_lm_conformer_encoder(): def test_transformer_decoder_layer_rel_pos(): - # Also tests RelPositionalEncoding embed_dim = 256 num_heads = 4 T = 25 @@ -94,6 +93,26 @@ def test_transformer_decoder_layer_rel_pos(): +def test_transformer_decoder_rel_pos(): + embed_dim = 256 + num_heads = 4 + T = 25 + N = 4 + C = 256 + pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0) + decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads) + decoder_norm = torch.nn.LayerNorm(embed_dim) + decoder = TransformerDecoderRelPos(decoder_layer, num_layers=6, norm=decoder_norm) + + + x = torch.randn(N, T, C) + x, pos_emb = pos_emb_module(x) + x = x.transpose(0, 1) # (T, N, C) + key_padding_mask = (torch.randn(N, T) > 0.0) # (N, T) + attn_mask = generate_square_subsequent_mask(T) + memory = torch.randn(T, N, C) + y = decoder(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask) + def test_transformer(): return From 26b5b5ba469d5b4f652941b9680ddd1fbff0d494 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 23 Aug 2021 19:05:31 +0800 Subject: [PATCH 11/26] Get tests to work for MaskedLmConformer --- egs/librispeech/ASR/conformer_lm/conformer.py | 48 +++++++++---------- egs/librispeech/ASR/conformer_lm/dataset.py | 26 +++++----- .../ASR/conformer_lm/test_conformer.py | 40 ++++++++++------ 3 files changed, 62 insertions(+), 52 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py index 6207dab84d..1963056cc9 100644 --- a/egs/librispeech/ASR/conformer_lm/conformer.py +++ b/egs/librispeech/ASR/conformer_lm/conformer.py @@ -52,8 +52,8 @@ def __init__( # self.embed is the embedding used for both the encoder and decoder. 
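        # (The embedding weights are initialized with scale 1/embed_scale and the
        # embedding output is multiplied by embed_scale in forward() and
        # decoder_nll(), so the embedded symbols start out with roughly unit
        # magnitude.)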
self.embed_scale = d_model ** 0.5 self.embed = nn.Embedding( - num_embeddings=self.decoder_num_class, embedding_dim=d_model, - _weight=torch.randn(self.decoder_num_class, d_model) * (1 / self.embed_scale) + num_embeddings=self.num_classes, embedding_dim=d_model, + _weight=torch.randn(self.num_classes, d_model) * (1 / self.embed_scale) ) self.encoder_pos = RelPositionalEncoding(d_model, dropout) @@ -69,9 +69,8 @@ def __init__( norm=nn.LayerNorm(d_model)) if num_decoder_layers > 0: - self.decoder_num_class = self.num_classes - decoder_layer = TransformerDecoderLayerRelPos( + decoder_layer = RelPosTransformerDecoderLayer( d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, @@ -82,14 +81,14 @@ def __init__( self.src_linear = torch.nn.Linear(d_model, d_model) decoder_norm = nn.LayerNorm(d_model) - self.decoder = TransformerDecoderRelPos( + self.decoder = RelPosTransformerDecoder( decoder_layer=decoder_layer, num_layers=num_decoder_layers, norm=decoder_norm, ) self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class + d_model, self.num_classes ) @@ -112,8 +111,8 @@ def forward( Returns: - Returns (encoded, pos_emb), where: - `encoded` is a Tensor containing the encoded data; it is of shape (N, T, C) + Returns (memory, pos_emb), where: + `memory` is a Tensor containing the encoded data; it is of shape (N, T, C) where C is the embedding_dim. `pos_emb` is a Tensor containing the relative positional encoding, of shape (1, 2*T-1, C) @@ -164,7 +163,7 @@ def decoder_nll( """ (T, N, C) = memory.shape - tgt_mask = generate_square_subsequent_mask(T, memory.device) + attn_mask = generate_square_subsequent_mask(T, memory.device) x = self.embed(src_symbols) * self.embed_scale # (N, T) -> (N, T, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) @@ -178,18 +177,17 @@ def decoder_nll( x, pos_emb, memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=key_padding_mask, - memory_key_padding_mask=key_padding_mask, - ) # (T, N, C) + attn_mask=attn_mask, + key_padding_mask=key_padding_mask) + # (T, N, C) pred = pred.permute(1, 0, 2) # (T, N, C) -> (N, T, C) pred = self.decoder_output_layer(pred) # (N, T, C) # nll: negative log-likelihood nll = torch.nn.functional.cross_entropy( - pred.view(-1, self.decoder_num_class), - tgt_symbols.view(-1), + pred.view(-1, self.num_classes), + tgt_symbols.reshape(-1), reduction="none", ) nll = nll.view(N, T) @@ -198,19 +196,19 @@ def decoder_nll( -class TransformerDecoderRelPos(nn.Module): - r"""TransformerDecoderRelPos is a stack of N decoder layers. +class RelPosTransformerDecoder(nn.Module): + r"""RelPosTransformerDecoder is a stack of N decoder layers. This is modified from nn.TransformerDecoder to support relative positional encoding. Args: - decoder_layer: an instance of the TransformerDecoderLayerRelPos() class (required). + decoder_layer: an instance of the RelPosTransformerDecoderLayer() class (required). num_layers: the number of sub-decoder-layers in the decoder (required). norm: the layer normalization component (optional). 
Examples:: - >>> decoder_layer = nn.TransformerDecoderLayerRelPos(d_model=512, nhead=8) - >>> transformer_decoder = nn.TransformerDecoderRelPos(decoder_layer, num_layers=6) + >>> decoder_layer = nn.RelPosTransformerDecoderLayer(d_model=512, nhead=8) + >>> transformer_decoder = nn.RelPosTransformerDecoder(decoder_layer, num_layers=6) >>> memory = torch.rand(10, 32, 512) >>> tgt = torch.rand(20, 32, 512) >>> pos_enc = torch.rand() @@ -219,7 +217,7 @@ class TransformerDecoderRelPos(nn.Module): __constants__ = ['norm'] def __init__(self, decoder_layer, num_layers, norm=None): - super(TransformerDecoderRelPos, self).__init__() + super(RelPosTransformerDecoder, self).__init__() self.layers = _get_clones(decoder_layer, num_layers) self.num_layers = num_layers self.norm = norm @@ -257,7 +255,7 @@ def forward(self, x: Tensor, return x -class TransformerDecoderLayerRelPos(nn.Module): +class RelPosTransformerDecoderLayer(nn.Module): """ Modified from torch.nn.TransformerDecoderLayer. Add it to use normalize_before (hardcoded to True), i.e. use layer_norm before the first block; @@ -278,7 +276,7 @@ class TransformerDecoderLayerRelPos(nn.Module): gelu (default=relu). Examples:: - >>> decoder_layer = nn.TransformerDecoderLayerRelPos(d_model=512, nhead=8) + >>> decoder_layer = nn.RelPosTransformerDecoderLayer(d_model=512, nhead=8) >>> memory = torch.rand(10, 32, 512) >>> tgt = torch.rand(20, 32, 512) >>> pos_emb = torch.rand(1, 20*2+1, 512) @@ -293,7 +291,7 @@ def __init__( dropout: float = 0.1, activation: str = "relu", ) -> None: - super(TransformerDecoderLayerRelPos, self).__init__() + super(RelPosTransformerDecoderLayer, self).__init__() self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) self.src_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) # Implementation of Feedforward model @@ -314,7 +312,7 @@ def __init__( def __setstate__(self, state): if "activation" not in state: state["activation"] = nn.functional.relu - super(TransformerDecoderLayerRelPos, self).__setstate__(state) + super(RelPosTransformerDecoderLayer, self).__setstate__(state) def forward( self, diff --git a/egs/librispeech/ASR/conformer_lm/dataset.py b/egs/librispeech/ASR/conformer_lm/dataset.py index 75f603d9d4..cad2c4d8f5 100644 --- a/egs/librispeech/ASR/conformer_lm/dataset.py +++ b/egs/librispeech/ASR/conformer_lm/dataset.py @@ -297,7 +297,7 @@ def collate_fn(sentences: List[List[int]], Will be reflected in the returned tgt_weights tensor. Returns a tuple (masked_src_symbols, src_symbols, - tgt_symbols, src_attn_mask, + tgt_symbols, src_key_padding_mask, tgt_weights), all with 2 axes and the same shape: (num_sent, seq_len). Their dtypes will be, respectively, @@ -315,7 +315,7 @@ def collate_fn(sentences: List[List[int]], tgt_symbols: The original sentences, with eos_symbol appended, and then padded with blank to the same length as masked_symbols and src_symbols. - src_attn_mask: Masking tensor for masked_src_symbols and src_symbols, to + src_key_padding_mask: Masking tensor for masked_src_symbols and src_symbols, to account for all the sentence lengths not being identical (makes each sentence's processing independent of seq_len). 
Tensor of Bool of shape (num_sent, seq_len), with True @@ -368,17 +368,17 @@ def collate_fn(sentences: List[List[int]], src_symbols = torch.tensor(srcs, dtype=torch.int64) masked_src_symbols = torch.tensor(srcs_masked, dtype=torch.int64) tgt_symbols = torch.tensor(tgts, dtype=torch.int64) - src_attn_mask = torch.tensor(attn_masks, dtype=torch.bool) + src_key_padding_mask = torch.tensor(attn_masks, dtype=torch.bool) tgt_weights = torch.tensor(weights, dtype=torch.float) - attn_mask_sum = torch.sum(torch.logical_not(src_attn_mask), dim=0).tolist() + attn_mask_sum = torch.sum(torch.logical_not(src_key_padding_mask), dim=0).tolist() while attn_mask_sum[-1] == 0: # Remove always-masked positions at the endof the lists. attn_mask_sum.pop() if len(attn_mask_sum) < seq_len: seq_len = len(attn_mask_sum) (src_symbols, masked_src_symbols, - tgt_symbols, src_attn_mask, tgt_weights) = (src_symbols[:,:seq_len], masked_src_symbols[:,:seq_len], - tgt_symbols[:,:seq_len], src_attn_mask[:,:seq_len], + tgt_symbols, src_key_padding_mask, tgt_weights) = (src_symbols[:,:seq_len], masked_src_symbols[:,:seq_len], + tgt_symbols[:,:seq_len], src_key_padding_mask[:,:seq_len], tgt_weights[:,:seq_len]) if randomize_proportion > 0.0: @@ -409,9 +409,9 @@ def tensor_split(t): check_collated_tensors(sentences, bos_sym, eos_sym, blank_sym, unmasked_weight, masked_src_symbols, src_symbols, - tgt_symbols, src_attn_mask, tgt_weights) + tgt_symbols, src_key_padding_mask, tgt_weights) return (masked_src_symbols, src_symbols, - tgt_symbols, src_attn_mask, tgt_weights) + tgt_symbols, src_key_padding_mask, tgt_weights) @@ -421,20 +421,20 @@ def check_collated_tensors(sentences: List[List[int]], blank_sym: int, unmasked_weight: float, masked_src_symbols, src_symbols, - tgt_symbols, src_attn_mask, + tgt_symbols, src_key_padding_mask, tgt_weights): """ This function checks the output of collate_fn, consider it test code. Please see the documentation of collate_fn to understand the args. """ - for t in src_symbols, tgt_symbols, src_attn_mask, tgt_weights: + for t in src_symbols, tgt_symbols, src_key_padding_mask, tgt_weights: assert t.shape == masked_src_symbols.shape tot_positions = src_symbols.numel() - masked_src_symbols, src_symbols, tgt_symbols, src_attn_mask, tgt_weights = ( + masked_src_symbols, src_symbols, tgt_symbols, src_key_padding_mask, tgt_weights = ( masked_src_symbols.tolist(), src_symbols.tolist(), tgt_symbols.tolist(), - src_attn_mask.tolist(), tgt_weights.tolist()) + src_key_padding_mask.tolist(), tgt_weights.tolist()) assert len(sentences) == len(masked_src_symbols) tot_masked_positions = 0 @@ -451,7 +451,7 @@ def check_collated_tensors(sentences: List[List[int]], if sentences[i] != reconstructed_sent: print(f"Error: sentence {i}={sentences[i]} differs from {reconstructed_sent}") (masked_src, src, tgt, src_mask, weights) = (masked_src_symbols[i], src_symbols[i], - tgt_symbols[i], src_attn_mask[i], tgt_weights[i]) + tgt_symbols[i], src_key_padding_mask[i], tgt_weights[i]) assert src[0] == masked_src[0] == bos_sym for j in range(len(masked_src)): diff --git a/egs/librispeech/ASR/conformer_lm/test_conformer.py b/egs/librispeech/ASR/conformer_lm/test_conformer.py index 99acfdcd08..45b50a6ea6 100644 --- a/egs/librispeech/ASR/conformer_lm/test_conformer.py +++ b/egs/librispeech/ASR/conformer_lm/test_conformer.py @@ -3,9 +3,10 @@ # python3 -m pytest test_conformer.py import torch +import dataset # from . 
from conformer import ( - TransformerDecoderRelPos, - TransformerDecoderLayerRelPos, + RelPosTransformerDecoder, + RelPosTransformerDecoderLayer, MaskedLmConformer, MaskedLmConformerEncoder, MaskedLmConformerEncoderLayer, @@ -80,7 +81,7 @@ def test_transformer_decoder_layer_rel_pos(): N = 4 C = 256 pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0) - decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads) + decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads) x = torch.randn(N, T, C) @@ -100,10 +101,9 @@ def test_transformer_decoder_rel_pos(): N = 4 C = 256 pos_emb_module = RelPositionalEncoding(C, dropout_rate=0.0) - decoder_layer = TransformerDecoderLayerRelPos(embed_dim, num_heads) + decoder_layer = RelPosTransformerDecoderLayer(embed_dim, num_heads) decoder_norm = torch.nn.LayerNorm(embed_dim) - decoder = TransformerDecoderRelPos(decoder_layer, num_layers=6, norm=decoder_norm) - + decoder = RelPosTransformerDecoder(decoder_layer, num_layers=6, norm=decoder_norm) x = torch.randn(N, T, C) x, pos_emb = pos_emb_module(x) @@ -114,18 +114,30 @@ def test_transformer_decoder_rel_pos(): y = decoder(x, pos_emb, memory, attn_mask=attn_mask, key_padding_mask=key_padding_mask) -def test_transformer(): - return - num_features = 40 +def test_masked_lm_conformer(): + num_classes = 87 - model = Transformer(num_features=num_features, num_classes=num_classes) + d_model = 256 + + model = MaskedLmConformer(num_classes,d_model) + N = 31 - for T in range(7, 30): - x = torch.rand(N, T, num_features) - y, _, _ = model(x) - assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes) + + (masked_src_symbols, src_symbols, + tgt_symbols, src_key_padding_mask, + tgt_weights) = dataset.collate_fn(sentences=[ list(range(10, 20)), list(range(30, 45)), list(range(50,68))], bos_sym=1, eos_sym=2, + blank_sym=0) + + # test forward() of MaskedLmConformer + memory, pos_emb = model(masked_src_symbols, src_key_padding_mask) + nll = model.decoder_nll(memory, pos_emb, src_symbols, tgt_symbols, + src_key_padding_mask) + print("nll = ", nll) + loss = (nll * tgt_weights).sum() + print("loss = ", loss) + def test_generate_square_subsequent_mask(): From 894be068e78496c648e2b1314ad4dd5f65a665e3 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 23 Aug 2021 19:51:58 +0800 Subject: [PATCH 12/26] Update prepare.sh to create LM training data; add missed scripts local/prepare_lm_training_data.py --- .../ASR/local/prepare_lm_training_data.py | 118 ++++++++++++++++++ egs/librispeech/ASR/prepare.sh | 8 +- 2 files changed, 125 insertions(+), 1 deletion(-) create mode 100755 egs/librispeech/ASR/local/prepare_lm_training_data.py diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py new file mode 100755 index 0000000000..b6e0931f40 --- /dev/null +++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang, Daniel Povey) + +""" + +This script takes a `bpe.model` and a text file such as `download/lm/librispeech-lm-norm.txt`, +and outputs the LM training data to a supplied directory such +as data/lm_training_data_bpe_5000. The format is as follows: + +It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a representation of +a dict with the following format: + + 'words' -> a k2._RaggedInt containing the BPE representations of each word, inexed by + integer word ID. 
(These integer word IDS are present in 'lm_data'). The + sentencepiece object can be used to turn the words and BPE units into + string form. + 'data' -> a k2._RaggedInt containing all the sentences, as word-ids (we don't output + the string form of this directly but it can be worked out together with + 'words' and the bpe.model). + +""" + +import argparse +from pathlib import Path +from typing import Dict, List, Tuple + +import k2 +import sentencepiece as spm +import torch + + + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "bpe_model", + type=str, + help="""Input BPE model, e.g. data/lang_bpe/bpe.model""" + ) + parser.add_argument( + "lm_data", + type=str, + help="""Input LM training data as text, e.g. data/downloads/lm/librispeech-lm-norm.txt""" + ) + parser.add_argument( + "lm_archive", + type=str, + help="""Path to output archive, e.g. lm_data.pt; look at the source of this script to see the format.""" + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + # word2index is a dictionary from words to integer ids. No need to reserve + # space for epsilon, etc.; the words are just used as a convenient way to + # compress the sequences of BPE pieces. + word2index = dict() + + words2bpe = [] # Will be a list-of-list-of-int, representing BPE pieces. + + sentences = [] # Wil be a list-of-list-of-int, representing word-ids. + + with open(args.lm_data) as f: + while True: + line = f.readline() + if line == '': + break + line_words = line.split() + for w in line_words: + if not w in word2index: + w_bpe = sp.Encode(w) + word2index[w] = len(words2bpe) + words2bpe.append(w_bpe) + sentences.append([ word2index[w] for w in line_words]) + + output = dict() + output['words' ] = k2.ragged.create_ragged2(words2bpe) + output['data'] = k2.ragged.create_ragged2(sentences) + + torch.save(output, args.lm_archive) + print(f"Saved to {args.lm_archive}") + + +if __name__ == "__main__": + main() + + + +# This was tested as follows. 
+# cat > foo <>> import k2 +#>>> import sentencepiece as spm +#>>> sp = spm.SentencePieceProcessor() +#>>> sp.load('data/lang_bpe/bpe.model') +#True +#>>> import torch +#>>> d = torch.load('bar.pt') +#>>> sp.Decode(k2.ragged.to_list(k2.index(d['words'], d['data']))) +#['THING TWO', 'ZOOLOGY'] +#>>> diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 798a306312..94c408c6e4 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -193,7 +193,13 @@ fi if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + lm_dir=lm_dir=data/lm_training_${vocab_size} + mkdir -p $lm_dir + log "Stage 9: creating $lm_dir/lm_data.pt" + ./local/prepare_lm_training_data.py data/lang_bpe_${vocab_size}/bpe.model download/lm/librispeech-lm-norm.txt $lm_dir/lm_data.pt + done fi cd data && ln -sfv lang_bpe_5000 lang_bpe From c3a87274465953971fc00c03e46205f9cec490f2 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 23 Aug 2021 22:28:45 +0800 Subject: [PATCH 13/26] Add train.py --- egs/librispeech/ASR/conformer_lm/dataset.py | 5 +- egs/librispeech/ASR/conformer_lm/madam.py | 959 ++++++++++++++++++++ egs/librispeech/ASR/conformer_lm/train.py | 607 +++++++++++++ 3 files changed, 1569 insertions(+), 2 deletions(-) create mode 100644 egs/librispeech/ASR/conformer_lm/madam.py create mode 100755 egs/librispeech/ASR/conformer_lm/train.py diff --git a/egs/librispeech/ASR/conformer_lm/dataset.py b/egs/librispeech/ASR/conformer_lm/dataset.py index cad2c4d8f5..3074d10999 100644 --- a/egs/librispeech/ASR/conformer_lm/dataset.py +++ b/egs/librispeech/ASR/conformer_lm/dataset.py @@ -3,7 +3,8 @@ import k2 import _k2 import sentencepiece as spm -from typing import Optional, List, Tuple +from pathlib import Path +from typing import Optional, List, Tuple, Union @@ -36,7 +37,7 @@ def __getitem__(self, i: int): return k2.index(self.words, sentence).values().tolist() -def load_train_test_lm_dataset(archive_fn: str, +def load_train_test_lm_dataset(archive_fn: Union[str,Path], test_proportion: float = 0.025) -> Tuple[LmDataset, LmDataset]: """ returns (train_lm_dataset, test_lm_dataset) diff --git a/egs/librispeech/ASR/conformer_lm/madam.py b/egs/librispeech/ASR/conformer_lm/madam.py new file mode 100644 index 0000000000..aa605c30b6 --- /dev/null +++ b/egs/librispeech/ASR/conformer_lm/madam.py @@ -0,0 +1,959 @@ +import logging +import math +import random +import torch +from torch import nn +from torch import Tensor +from torch.optim.optimizer import Optimizer +from typing import List, Tuple + + + +# After this many warnings about infinite gradients we'll die. +inf_grad_count = 0 +inf_grad_max_count = 20 + +class Madam(Optimizer): + r"""Madam is a modification of the Adam algorithm, with various changes + intended to support certain "common-sense" ideas and solve common + pathologies that can happen particularly in transformer-type models that + have multiplication of parameters (particularly, key and query matrices)-- + these can be vulnerable to "subspace loss" where, if you have any l2 + regularization, certain subspaces in the key/query space might get + regularized towards zero. We solve this with a special formula that + changes how the l2/weight-decay is done (see compute_l2_grad()). + I'll try to write the math down at some point. This formula only + applies to tensors that have at least two dimensions; for one-dimensional + tensors we simply won't do l2 regularization. 
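+
+    Example (an illustrative sketch; it assumes nothing beyond the constructor
+    arguments documented under Args below)::
+
+        >>> model = torch.nn.Linear(256, 256)
+        >>> optimizer = Madam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
+        >>> loss = model(torch.randn(8, 256)).pow(2).mean()
+        >>> loss.backward()
+        >>> optimizer.step()
+        >>> optimizer.zero_grad()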
+ + One more thing-- there is a special pathology that can sometimes afflict + models like LSTMs, where a particular element of a minibatch experiences + gradient blowup in the backward pass. We'd like to identify such cases and + fix it somehow, e.g. by removing or scaling down the gradient for that + particular minibatch. We can identify and somewhat fix this by seeing that the + gradient norm (computed over all the parameters in a parameter group) is + much more than on previous minibatches, and limiting it to (the preceding + average step size times some constant). + + Like most optimization algorithms, for this to work well you need to + have an appropriate learning rate schedule, either decreasing with + time, or increasing (warm-up) and then decreasing. The LR schedule may + possibly need to decrease a little more aggressively than you would with + Adam, or at least have smaller values overall than Adam, because + the smaller parameters will mean the effective (relative) learning + rate is higher. + + This is modified from PyTorch's optim/adam.py + + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + grad_norm_buffer_size (int, optional): Buffer size used in detecting + minibatches with unusually large gradients and scaling them down. + limit_grad_factor (float): factor by which we don't allow the + gradient to be greater than the average of previous gradients + (we'll scale the gradient down, over the whole param-group, + to enforce this). Must be greater than 1. Set to float('inf') + to disable norm clipping. + min_target_rms: A floor on the "target rms" of each Tensor, so + that Tensors that, when initialized, have less than this + rms value will have their target rms value floored to this + l2: True to enable l2 regularization + l2_period: You may set this to a value greater than one to save + computation by only periodically doing the l2 update. + We include a scaling factor in the formula so that, as far + as possible (for small learning rates) this shouldn't affect + the results. (Note: this probably isn't necessary to set, + since it turns out the update is quite fast, at least on GPU, + and the gradient clipping is actually more of a problem) + + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + + """ + + def __init__(self, params, + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + grad_norm_buffer_size: int = 8, + limit_grad_factor: float = 2.0, + min_target_rms: float = 0.05, + l2: bool = True, + l2_period: int = 1): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not (isinstance(grad_norm_buffer_size, int) and grad_norm_buffer_size > 1): + raise ValueError("Invalid grad_norm_buffer_size value: {}".format(grad_norm_buffer_size)) + if not limit_grad_factor > 1.0: + raise ValueError("Invalid limit_grad_factor: {}".format(limit_grad_factor)) + if not isinstance(l2, bool): + raise ValueError("Invalid l2 value: {}".format(l2)) + if not l2_period >= 1: + raise ValueError("Invalid l2_period value: {}".format(l2_period)) + defaults = dict(lr=lr, betas=betas, eps=eps, + grad_norm_buffer_size=grad_norm_buffer_size, + limit_grad_factor=limit_grad_factor, + l2=l2, l2_period=l2_period, + min_target_rms=min_target_rms) + super(Madam, self).__init__(params, defaults) + + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + + beta1, beta2 = group['betas'] + grad_norm_buffer_size = group['grad_norm_buffer_size'] + limit_grad_factor = group['limit_grad_factor'] + min_target_rms = group['min_target_rms'] + + # The next 5 lists are part of the original Adam optimizer + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + + # The next 3 lists are not part of the original Adam optimizer. + target_rms_values = [] # relates to weight decay. Target root-mean-square + # values of the elements of each parameter + # we are optimizing + prev_norm_stats = [] # contains Tensor with 2 elements each, the sum + # of the [sum_squared, count] of + # this parameter on previous minibatches (up to + # grad_norm_buffer_size minibatches) + cur_grad_norms = [] # and `cur_grad_norms` contains the squared l2 + # norm norm of this step's gradient for this + # parameter, as a Tensor. + + + for p in group['params']: + if p.grad is not None: + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + grads.append(p.grad) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + # The things below are not part of original Adam, they are the Madam extension.. 
+ state['target_rms'] = _get_target_rms(p, min_target_rms) + # grad_norm_buf is a rotating buffer containing (grad_norm**2, count), where + # count is 1 for grad_norms that are set and 0 for those that are not set because + # we're near step 0 or because they were infinite. + state['grad_norm_buf'] = torch.zeros(grad_norm_buffer_size, 2, device=p.device) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + + target_rms_values.append(state['target_rms']) + + cur_step = state['step'] + if limit_grad_factor != float('inf'): + grad_norm_buf = state['grad_norm_buf'] + cur_grad_norm = (p.grad ** 2).sum() # actually squared nom + prev_mean_norm = grad_norm_buf.sum(0) # prev_mean_norm is a Tensor [ tot_norm_squared, count ] + grad_norm_buf[cur_step % grad_norm_buffer_size][0] = cur_grad_norm + grad_norm_buf[cur_step % grad_norm_buffer_size][1].fill_(1.0) + prev_norm_stats.append(prev_mean_norm) + cur_grad_norms.append(cur_grad_norm) + + # update the steps for each param group update + cur_step += 1 + state['step'] = cur_step + # record the step after step update + state_steps.append(cur_step) + + if limit_grad_factor != float('inf'): + self._apply_grad_norm_clipping(group['params'], + prev_norm_stats, cur_grad_norms, grads, + limit_grad_factor, grad_norm_buffer_size) + + _madam(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + target_rms_values, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + eps=group['eps'], + l2=group['l2'], + l2_period=group['l2_period']) + + return loss + + + def _apply_grad_norm_clipping(self, + params_list, + prev_norm_stats: List[Tensor], + cur_grad_norms: List[Tensor], + grads: List[Tensor], + limit_grad_factor: float, + grad_norm_buffer_size: int) -> None: + """ + This function applies gradient norm clipping for this parameter group if this + minibatch has substantially larger gradients in this param group than + recent minibatches. The idea is to catch cases like where an LSTM + happens to blow up in the backward pass, or some code bug causes very + large or infinite gradients on a particular minibatch; so we scale + down any very large gradients and zero infinite ones. + + Args: + params_list: some kind of iterable or list of params in this group + prev_norm_stats: a list which, for each parameter in this group + with a grad, contains a Tensor with 2 elements each, containing + # the [sum, count] of up to `grad_norm_buffer_size` + # norms of this parameter on previous minibatches; + cur_grad_norms: a list of Tensor containing, for each parameter in this group, + the norm of this step's gradient for this parameter. + grads: List of gradients with the same order as prev_norm_stats and + cur_grad_norms + limit_grad_factor: a float >1.0 (e.g. 4.0) that dictates + how-much-larger-than-average gradients we allow before clipping. + grad_norm_buffer_size: an int that determines the rolling buffer size over which + we store gradient norms + """ + num_params = len(prev_norm_stats) + assert len(grads) == num_params + + all_prev_norm_stats, all_cur_grad_norms = _to_device('cpu', + torch.stack(prev_norm_stats), + torch.stack(cur_grad_norms)) + assert all_prev_norm_stats.shape == (num_params, 2) + assert all_cur_grad_norms.shape == (num_params,) + + # divide totals by counts (i.e. counts of iterations were we stored + # a finite grad) + all_prev_grad_norms = all_prev_norm_stats[:,0] / all_prev_norm_stats[:,1] + # prev_norm and cur_norm are floats, they are actually squared norms. 
+ prev_norm = all_prev_grad_norms.sum().item() + cur_norm = all_cur_grad_norms.sum().item() + + if prev_norm - prev_norm != 0.0: + # There were zero counts; fix this by using the current grad norm + # for affected parameters, and recompute all_prev_grad_norms and + # prev_norm. + for i in range(num_params): + if all_prev_norm_stats[i][1] == 0.0: + # if count is 0 and cur norm is finite, use cur norm as our estimate + # of previous norms. This would only be useful if some but not + # all params were in this situation of having no previous estimates. + cur = all_cur_grad_norms[i] + if cur - cur == 0.0: # finite.. + all_prev_norm_stats[i][0] = cur + all_prev_norm_stats[i][1] = 1.0 + else: + # 0.0 is a default; likely won't matter, as if we + # get infinite `cur`, we'll abandon this minibatch. + all_prev_norm_stats[i][0] = 0.0 + all_prev_grad_norms = all_prev_norm_stats[:,0] / all_prev_norm_stats[:,1] + prev_norm = all_prev_grad_norms.sum().item() + + # Deal with infinite gradients. + if cur_norm - cur_norm != 0: # cur_norm is infinite or NaN + global inf_grad_count + logging.warning(f'Infinite gradient-norm detected (cur/prev: {cur_norm}/{prev_norm}): will ' + f'zero grad ({inf_grad_count}/{inf_grad_max_count} times until dying)') + inf_grad_count += 1 + if inf_grad_count >= inf_grad_max_count: + assert 0, "Reached max count of infinite gradient-norm stats" + # Zero all gradients in this group + for g in grads: + g[:] = 0. + # .. and zero the stored gradient norms in grad_norm_buf (so + # that infinities don't ruin our stats of previous batches) + for p in params_list: + if p.grad is not None: + state = self.state[p] + grad_norm_buf = state['grad_norm_buf'] + # cur_step is the location where we would have written the grad_norm. + # We didn't check if it was infinity before, because we didn't want to + # incur lots of GPU->CPU transfers. + cur_step = state['step'] - 1 + # Remove this 'bad' step from the buffer. + grad_norm_buf[cur_step % grad_norm_buffer_size][:] = 0.0 + else: + # cur_norm is finite. Check whether we have to clip this iteration's grad. + # we always remove infinities/NaNs from the buffer, so prev_norm should not + # be infinite or NaN. + assert prev_norm - prev_norm == 0.0 + # cur_norm and prev_norm are actually squared norms, so we need to + # square limit_grad_factor.. + limit_grad_factor2 = limit_grad_factor ** 2 + if cur_norm > prev_norm * limit_grad_factor2: + grad_factor2 = (prev_norm * limit_grad_factor2) / cur_norm + grad_factor = grad_factor2 ** 0.5 + cur_norm_f, prev_norm_f, grad_factor_f = ('%.2g' % cur_norm, '%.2g' % prev_norm, + '%.2g' % grad_factor) + logging.warning(f'Gradient norm exceeds average of last {grad_norm_buffer_size} ' + f'gradients times {limit_grad_factor}: cur/prev {cur_norm_f}/{prev_norm_f}: ' + f'scaling it by {grad_factor_f}.') + for g in grads: + g[:] *= grad_factor + # .. and scale down the stored gradient norms in grad_norm_buf, to + # avoid the bound getting too loose too quickly. + for p in params_list: + if p.grad is not None: + state = self.state[p] + grad_norm_buf = state['grad_norm_buf'] + cur_step = state['step'] - 1 + # the buffer contains squared norms, so multiply by grad_factor2 + grad_norm_buf[cur_step % grad_norm_buffer_size][0] *= grad_factor2 + + +def _to_device(device, *args): + """ + Transfers a tuple of Tensors from one device to another, using a single transfer. Must have + same dtype but may have different shapes. + E.g. 
+ (cpu_tensor_a, cpu_tensor_b) = _to_device('cpu', gpu_tensor_a, gpu_tensor_b) + """ + if device == args[0].device: + return args + else: + arg0 = args[0] + combined_src = torch.cat([ x.reshape(-1) for x in args ]) + combined_dest = combined_src.to(device) + dests = [] + offset = 0 + for src in args: + numels = src.numel() + dests.append(combined_dest[offset:offset+numels].reshape(src.shape)) + offset += numels + return tuple(dests) + + + +def _get_target_rms(x: Tensor, min_target_rms: float) -> Tensor: + """ + Returns Tensor with one element, representing a target root-mean-square + value of elements of x, that we consider "reasonable", and will use a + as a "target rms" in our modified weight-decay formula. It returns + the maximum of the current RMS of the values of x, and `min_target_rms`, + as a Tensor on the same device as x. + """ + with torch.no_grad(): + # `norm` is the 2-norm of x currently (and this function should be + # called right after parameter initialization) + rms = ((x ** 2).sum() / x.numel()).sqrt() + largest_dim = max(list(x.shape)) + numel = x.numel() + if min_target_rms > 0.0: + rms = rms.clamp(min=min_target_rms) + if x.ndim > 1 and __name__ == '__main__': # will only be used for x.ndim > 1. + print("Target rms = ", rms) # Print this in testing only. + return rms + + +def _madam(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[int], + target_rms_values: List[Tensor], + *, + beta1: float, + beta2: float, + lr: float, + eps: float, + l2: bool, + l2_period: int): + r"""This is a modification of adam() from torch's optim/_functional.py. + + It has been modified to: + (i) remove the amsgrad option; this shouldn't be as necessary due to + the adaptive gradient norm clipping we have added + (ii) add our special formula for l2 regularization. This doesn't have + any tunable parameters, other than the target standard deviation + of the elements of the tensor (which is passed in as target_rms). + Args: + params: list of Tensor, containing the parameters to be optimized + grads: list of Tensor, containing the gradients corresponding to + each of the params (grads[i] should correspond to params[i].grad, + although it may have undergone gradient clipping). + exp_avgs: list of Tensor, containing tensors with the same dimensions + as params and grads, that contain the moving-averages of + `grads`. + exp_avg_sqs: list of Tensor, containing tensors with the same dimensions + as params and grads, that contain the moving-averages of + `grads ** 2`. + state_steps: list of int, containing the step for each parameter (step >= 1) + target_rms_values: list of Tensor with one element each, containing the + target root-mean-square values of each parameter tensor in `params` + l2: a bool, where if true we will activate the l2 regularization + formula. + l2_period: an integer that determines how often (i.e. every how many + minibatches) we apply the l2 update. We include a scaling factor + so that as far as possible the result will not be too sensitive + to the value of this. + + beta1: decay factor for gradients, e.g. 0.9 + beta2: decay factor for gradients squared, e.g. 0.999 + lr: learning rate, e.g. 0.0001 + eps: a small constant used to prevent division by zero, e.g. 1.0e-8 + + See :class:`~torch.optim.Adam` for details. 
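
    In brief, ignoring the l2 term, the update applied below is the usual Adam
    step (this is only a compact restatement of the code that follows, for
    orientation):

        exp_avg    = beta1 * exp_avg    + (1 - beta1) * grad
        exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad**2
        denom      = sqrt(exp_avg_sq / (1 - beta2**step)) + eps
        param     -= (lr / (1 - beta1**step)) * exp_avg / denom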
+ """ + assert len(params) == len(grads) == len(state_steps) == len(exp_avgs) == len(exp_avg_sqs) == len(target_rms_values) + + for i, param in enumerate(params): + + grad = grads[i] + + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step = state_steps[i] + target_rms = target_rms_values[i] + + bias_correction1 = 1 - beta1 ** step + bias_correction2 = 1 - beta2 ** step + + do_l2 = param.ndim > 1 and l2 and step % l2_period == 0 + + if do_l2: + # This represents just the "noise term" of the gradient, i.e. the grad minus the + # running mean. We'll later divide by denom. + cur_grad_noise = (grad - exp_avg) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + step_size = lr / bias_correction1 + + if not do_l2: + param.addcdiv_(exp_avg, denom, value=-step_size) + else: + # We can treat "pseudo_grad" as if it were a gradient (even though it's + # actually a gradient times a per-element learning rate). The analysis + # that we used to figure out what the l2 should be did not use the fact + # that the gradients were actually gradients, it simply analyzed it as a + # quantity that can be treated as close to zero-mean and with a certain + # structure of variance, and added to the param with the formula: + # + # param -= step_size * grad + # + # The original analysis assumed the gradients were independent from frame + # to frame; in fact these are not, but the important difference can be captured + # in a scalar `grad_scale` that expresses the scale of pseudo_grad relative + # to the independent gradients that we are effectively adding on each frame + # (but with a delay). + + pseudo_grad = exp_avg / denom + cur_pseudo_grad = cur_grad_noise / denom + + # grad_scale expresses the expected size of cur_pseudo_grad relative to the + # original grads if we had not done the moving-average; it is the sqrt of + # the sum of the squares of coefficients of previous gradients: + # c_n = (1-beta1) beta1^n, for + # n = 0, 1, .. + # .. plus one which is the sumsq of the coefficient of 'grad' itself in + # (grad - exp_avg). + # It is relevant that the sum of the coefficients (i.e. not squared) is 1; + # if this were not so we'd have to incorporate that into the formula for l2. + grad_scale = (((1 - beta1)**2) / (1 - beta1**2) + 1) ** 0.5 + + with torch.no_grad(): + l2_grad = _compute_l2_grad(param, cur_pseudo_grad, target_rms, + rho=step_size, grad_scale=grad_scale, + period_scale=l2_period, + eps=eps, safe=True) + + # TODO: could alternate computing l2 on only, say, odd frames, and scale it + # up by 2, to save time. + param.add_(pseudo_grad + l2_grad, alpha=-step_size) + + + +def _view_as_matrix(x: Tensor, dim: int) -> Tensor: + """ + Returns a Tensor of shape (n, x.shape[dim]), where n is the product + of the sizes of the other dimensions of x. This may involve a copy, + if x cannot be reshaped in this way. + """ + ndim = x.ndim + assert ndim > 1 and dim >= 0 and dim < ndim + # Move the dim to the last position in x.. + if dim != ndim - 1: + x = x.transpose(dim, ndim - 1) + return x.reshape(-1, x.shape[-1]) + + +def _outer_product(x: Tensor, dim: int) -> Tensor: + """ + Returns a Tensor of shape (x.shape[dim], x.shape[dim]) formed by + summing the outer products of all the vectors in x of size + `x.shape[dim]`, that we get by indexing x with all tuples of dimensions + on other axes. E.g. 
if x is a matrix and dim == 0, this would + be torch.matmul(x, x.transpose(0, 1)). + + Note: x must have at least 2 dimensions, x.ndim >= 2. + """ + x = _view_as_matrix(x, dim) + return torch.matmul(x.transpose(0, 1), x) + +def _multiply_on_dim(x: Tensor, m: Tensor, dim: int) -> Tensor: + """ + Multiplies x by the matrix m which must be of shape: + (x.shape[dim], n)), with `dim` as the dimension/axis on + x to be multiplied. + + Caution: result may not have the same layout/strides as x, + although it will have the same shape. + + Args: + x: Tensor to be multiplied; must have ndim >= 2 + m: Symmetric matrix to multiply x by; must have + m.shape == (x.shape[dim], x.shape[dim]) + dim: Dimension of x to multiply on, with 0 <= dim < x.ndim + Return: + The matrix product, of the same shape as + x, except with the size on dimension `dim` being n. + """ + ndim = x.ndim + if dim != ndim - 1: + x = x.transpose(dim, ndim - 1) + ans = torch.matmul(x, m) + if dim != ndim - 1: + # Swap the dimensions back to what they were originally. + ans = ans.transpose(dim, ndim - 1) + return ans + + +def _multiply_product_combined(l2: Tensor, grad: Tensor, dim: int, + need_grad_sumsq: bool): + """ + This function is an optimized version of the following code: + outer_prod = _outer_product(grad, dim) + l2 = _multiply_on_dim(l2, outer_prod, dim) + if dim == 0: # could choose any dim for this + grad_sumsq = torch.trace(outer_prod) + Args: + l2: The l2 matrix which starts out as the parameter tensor x, must have >= 2 diims + grad: The gradient tensor (or a gradient-like quantity); must + have same shape as l2. + dim: The dimension of l2 and grad that we want this to + act on, with 0 <= dim < l2.ndim. We multiply l2, on + this dim, by a symmetric quantity of shape + (l2.shape[dim], l2.shape[dim]), that is formed + by a product and sum on grad (this is a matrix + product, if there are 2 axes). + Returns: + Returns (l2, grad_sumsq), where l2 is the result of + multiplying l2 by the product mentioned above, and + grad_sumsq is either None, or a Tensor representing + the sum-of-squares of `grad`; for at least one + dim with 0 <= dim < l2.ndim, we guarantee to + return such a Tensor. + """ + grad = _view_as_matrix(grad, dim) + if grad.shape[1] <= grad.shape[0]: + # Minimize the size of the intermediate product, which will probably well reflect + # the compute time since memory access can be limiting on CUDA.a + grad_product = torch.matmul(grad.transpose(0, 1), grad) + l2 = _multiply_on_dim(l2, grad_product, dim) + if need_grad_sumsq: + grad_sumsq = torch.trace(grad_product) + else: + grad_sumsq = None + return (l2, grad_sumsq) + else: + l2 = _multiply_on_dim(l2, grad.transpose(0, 1), dim) + l2 = _multiply_on_dim(l2, grad, dim) + # This branch does not compute grad_sumsq, but we're bound to + # take the other branch on at least one occasion. + return (l2, None) + + + +def _compute_l2_grad(x: Tensor, grad: Tensor, target_stddev: float, rho: float, + grad_scale: float = 1.0, period_scale: int = 1, + eps: float = 1.0e-08, + safe: bool = True) -> Tensor: + """ + Returns the l2 gradient of x, which will be added to 'grad'. + This is a more principled replacement for the typical l2 regularization + formula where we do: + grad += weight_decay * x. + (Note: this must only be called if x.ndim >= 2). 
+ + For x with 2 axes, we instead do this: + + grad += (rho / (2*target_stddev**2)) * (grad grad^T) x (grad^T grad) / trace(grad^T grad), + + where the implicit multiplication above refers to matrix multiplication; note, x means + the variable x. We'll have to write the justification of this, which is a little + complicated, separately; it has to do with using exactly the amount of l2 in each + subspace of each dimension of x, to to cancel out the gradient noise. + + Args: + x: parameter to be updated. MUST HAVE x.ndim >= 2. + grad: Gradient for x on this iteration (or at least, something that + is treated like a gradient in the update formula) +target_stddev: The target standard deviation (uncentered), of elements of x. + This is our estimate of what standard deviation these elements would + have in a well-trained model; it is set by some kind of heuristic. + rho: The learning rate we are going to use, as in: x -= (grad + l2) * rho. + grad_scale: A scale whereby the caller asserts that `grad` is some + quantity that is distributed like the real + gradient times `grad_scale` (this is useful when the provided `grad` + is really a moving average gradient). Because the l2 term's magnitude + is proportional to the gradient squared, we need to divide it by the + square of grad_scale, so this function uses 1/grad_scale^2 as a scaling + factor. +period_scale: An integer scale that we use to compensate for the fact that this + weight decay is only applied periodically, once every + `period_scale` minibatches. Accordingly, we make the l2 term + that many times larger. + eps: A small constant used to avoid division by zero + safe: If true, use a safe version of the formula that checks for + 'overshoot' of l2 regularization and fixes the issue (might + be an issue for models that are getting unstable or have high + learning rate) + + + Returns: + Returns l2 pseudo-gradient (term to be added to `grad`). + """ + assert x.shape == grad.shape + assert x.ndim >= 2 + + l2 = x + grad_sumsq = None + num_ignored_dims = 0 # for an optimization for when size=1 on some dim. + for dim in range(x.ndim): + # The code below is an optimization of the following few lines, + # which were perhaps easier to understand: + # outer_prod = _outer_product(grad, dim) + # l2 = _multiply_on_dim(l2, outer_prod, dim) + # if dim == 0: # could choose any dim for this + # grad_sumsq = torch.trace(outer_prod) + if x.shape[dim] <= 1: + num_ignored_dims += 1 + continue + (l2, maybe_grad_sumsq) = _multiply_product_combined(l2, grad, dim, + grad_sumsq is None) + if maybe_grad_sumsq is not None: + grad_sumsq = maybe_grad_sumsq + if grad_sumsq is None: + # We shouldn't reach here, except if at some point we start calling this + # code for tensors with ndim <= 1, or with numel() == 1. + grad_sumsq = (grad ** 2).sum() + + # l2 is the amount of l2, we'll subtract this from x, as in: + # x -= rho * (grad + l2). + + factor = rho * period_scale / (2.0 * (target_stddev * grad_scale)**2) + l2 = l2 * (factor / (grad_sumsq ** (x.ndim - 1 - num_ignored_dims) + eps)) + + if safe and rho > 0: + #x2_sum = (x ** 2).sum() + l2_sum = (l2 ** 2).sum() * (rho * rho) + cross_sum = (x * l2).sum() * rho + alpha = cross_sum / (l2_sum + eps) + # We want to minimize the sum-of-squares of (x - alpha * rho * l2), where alpha + # is a constant in [0,1] that we are about to estimate, intended to prevent + # instability by scaling down our weight decay formula. 
Right now (and treating + # things as if they were scalars for brevity): + # x2_sum = x * x + # l2_sum = rho * rho * l2 * l2 + # cross_sum = x * rho * l2 + # We want to minimize the sum-sq of (x - alpha * rho * l2), + # i.e. we want to choose alpha to minimize: + # x2_sum - 2 * alpha * cross_sum + alpha^2 * l2_sum + # d/dalpha of this, is: + # -2*cross_sum + 2 * alpha * l2_sum + # and setting this to zero and solving for alpha, we have: + # alpha = cross_sum / l2_sum. + # If it turns out that alpha >= 1, then we just use alpha=1 + # (the original formula), as there is no problem with + # instability/overshoot. + l2.mul_(alpha.clamp(max=1.0)) + if random.random() < 0.001 and alpha < 1.0: + logging.info(f'madam optimizer: alpha={alpha}, shape={tuple(x.shape)}') + return l2 + + + +class Moam(object): + """ + Implements Moam optimizer. This is a modified version of the Noam optimizer + which was proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf, + but changed to use Madam (see above) instead of Adam as the base optimizer. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py + + Caution: you probably want to set 'factor' to a smaller value than you would typically + use for a corresponding Noam optimizer, because Moam does a kind of l2 regularization which + keeps the parameters fairly small, so the relative changes in model parameters + will be larger than Noam, for any given learning rate. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + model_size: attention dimension of the transformer model + factor: learning rate factor, that multiplies the output of the + formula based on model size + warm_step: number of warmup steps before the learning rate starts to decrease + (it increases until this point). + min_target_rms: this is a parameter of the Madam optimizer; it represents a floor + on the "target root-mean-square value" that is used when the initialization + of a tensor is zero or below this value. It may be worth optimizing. + Don't worry about tensors with fewer than 2 dimensions when setting this, + these are not subject to our l2 formula. + limit_grad_factor: you can set this to a finite value, e.g. 2.0, to activate + a mechanism that limits the norms of larger-than-usual gradients. + This seems to cause a slowdown, likely due to GPU->CPU transfers. + l2_period: mechanism to improve the optimization speed, by only applying the l2 + regularization (which is a complicated formula) every this-many + minibatches. E.g. can set it to 2 or 4. 
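
    Example (a rough usage sketch; `model`, `train_loader` and `compute_loss`
    are placeholders, not part of this file):

        optimizer = Moam(model.parameters(), model_size=256, factor=2.0,
                         warm_step=25000)
        for batch in train_loader:
            optimizer.zero_grad()
            loss = compute_loss(model, batch)
            loss.backward()
            optimizer.step()  # updates the parameters and advances the LR schedule

    The learning rate on step s is factor * model_size**(-0.5) *
    min(s**(-0.5), s * warm_step**(-1.5)): it grows linearly up to `warm_step`
    and then decays as 1/sqrt(s).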
+ """ + + def __init__(self, params, model_size: int = 256, + factor: float = 2.0, warm_step: int = 25000, + min_target_rms: float = 0.05, + limit_grad_factor: float = float('inf'), + l2_period: int = 1) -> None: + """Construct an Noam object.""" + self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9, + min_target_rms=min_target_rms, + limit_grad_factor=limit_grad_factor, + l2_period=l2_period) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) + + + +class TestModel(torch.nn.Module): + """Class for testing the Madam optimizer""" + def __init__(self): + super(TestModel, self).__init__() + self.first_layers = torch.nn.Sequential( + torch.nn.Linear(100, 200), + torch.nn.ReLU(), + torch.nn.Linear(200, 300), + torch.nn.ReLU()) + self.conv1 = torch.nn.Conv1d(in_channels=300, out_channels=200, + kernel_size=1) + self.relu = torch.nn.ReLU() + self.conv2 = torch.nn.Conv1d(in_channels=200, out_channels=250, + kernel_size=3) + + + def forward(self, x): + # from (B, T, 100) to (B, T, 200) + x = self.first_layers(x) + # B, T, C -> B, C, T + x = x.transpose(1, 2) + x = self.conv2(self.relu(self.conv1(x))) + # B, C, T -> B, T, C + x = x.transpose(1, 2) + return x + +def test_madam(): + print("Testing Madam optimizer") + global inf_grad_max_count + inf_grad_max_count = 200 + if torch.cuda.is_available(): + devices_and_l2 = [(torch.device('cuda'), True), + (torch.device('cuda'), False)] + #(torch.device('cpu'), True), + #(torch.device('cpu'), False)] + else: + devices_and_l2 = [(torch.device('cpu'), True), + (torch.device('cpu'), False)] + + + for (device, l2) in devices_and_l2: + model = TestModel().to(device) + # min_target_rms=0.01 is for testing, so the target equals the initial RMS + # and we can more easily tell whether our update has the desired effect. + # I also tested this with betas=(0.1, 0.98), to check that the effect of + # `grad_scale` was correct (it only makes much difference for small beta). 
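        # (For reference: the grad_scale formula used in _madam,
        #  sqrt((1-beta1)**2 / (1-beta1**2) + 1), is roughly 1.35 for beta1=0.1 but
        #  only about 1.03 for beta1=0.9, so its effect is mostly visible with a
        #  small beta1.)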
+ optimizer = Madam(model.parameters(), lr=0.0005, betas=(0.9, 0.98), + l2=l2, min_target_rms=0.01, l2_period=1) + #optimizer = torch.optim.Adam(model.parameters()) + + def get_elems_rms(x: Tensor) -> Tensor: + return ((x ** 2).sum() / x.numel()).sqrt().item() + + for i in range(1000): + if i % 100 == 0: + rms_values = (get_elems_rms(model.first_layers[0].weight), + get_elems_rms(model.first_layers[2].weight), + get_elems_rms(model.conv1.weight), + get_elems_rms(model.conv2.weight)) + print(f"Iter {i}, l2={l2}, device={device}: stddevs = {rms_values} ") + B = 4 + T = 20 + x = torch.randn(B, T, 100).to(device) + y = model(x) + yderiv = torch.randn_like(y) + if i % 190 <= 3 and i > 0: + yderiv *= 100.0 + if i % 550 == 0 and i > 0: + yderiv *= float('inf') + + y.backward(gradient=yderiv) + optimizer.step() + model.zero_grad() + print("") + +def test_moam(): + print("Testing Moam optimizer") + model = TestModel() + # min_target_rms=0.01 is for testing, so the target equals the initial RMS + # and we can more easily tell whether our update has the desired effect. + optimizer = Moam(model.parameters(), factor=1.0, warm_step=300, + min_target_rms=0.01) + + + def get_elems_rms(x: Tensor) -> Tensor: + return ((x ** 2).sum() / x.numel()).sqrt().item() + + for i in range(1000): + if i % 100 == 0: + rms_values = (get_elems_rms(model.first_layers[0].weight), + get_elems_rms(model.first_layers[2].weight), + get_elems_rms(model.conv1.weight), + get_elems_rms(model.conv2.weight)) + print(f"Iter {i} (Moam): stddevs = {rms_values} ") + B = 4 + T = 20 + x = torch.randn(B, T, 100) + y = model(x) + yderiv = torch.randn_like(y) + if i % 190 <= 3 and i > 0: + yderiv *= 100.0 + if i % 550 == 0 and i > 0: + yderiv *= float('inf') + + y.backward(gradient=yderiv) + optimizer.step() + model.zero_grad() + print("") + + + +def test_to_device(): + if not torch.cuda.is_available(): + return + a_gpu = torch.ones(1,2,3,4, device='cuda') + b_gpu = torch.zeros(3,8, device='cuda') + (a_cpu, b_cpu) = _to_device('cpu', a_gpu, b_gpu) + print("a_cpu,b_cpu = ", a_cpu, b_cpu) + (a_gpu2, b_gpu2) = _to_device('cuda', a_cpu, b_cpu) + print("a_gpu2,b_gpu2 = ", a_gpu2, b_gpu2) + +# Caution: this testing code is not very automated, it reqires looking at the output to +# make sure it looks right. The main thing is that with l2=True, the printed stddevs stay close +# to the "Target rms" values, which are printed out; while with l2=False, the stddevs +# increase to significantly higher than that. +# +# The test of the Moam optimizer is mainly to make sure it runs; the scale of the +# gradients, and the learning rate, are such that one of the rms's stays quite a bit +# above the target value, i.e. (0.047, 0.044, 0.047), vs. targets of +# (0.057, 0.04, 0.019), I think this has to do with the alpha<1 stability mechanism being +# activated, the l2 does have an effect, as I verified by changing the code to set +# l2=False. +def main(): + # Set number of threads to 1, or Torch can do weird things that make it extremely slow. + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + #test_to_device() + random.seed(0) + torch.random.manual_seed(0) + test_madam() + #test_moam() + + +if __name__ == '__main__': + main() diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py new file mode 100755 index 0000000000..85d671ecc6 --- /dev/null +++ b/egs/librispeech/ASR/conformer_lm/train.py @@ -0,0 +1,607 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang, Daniel Povey) +# +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import dataset # from . +import madam # from . +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from conformer import MaskedLmConformer +from lhotse.utils import fix_random_seed +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from madam import Moam + +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist + +from icefall.utils import ( + AttributeDict, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + is saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - exp_dir: It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + + - lr: It specifies the initial learning rate + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - start_epoch: If it is not zero, load checkpoint `start_epoch-1` + and continue training from that checkpoint. + + - num_epochs: Number of epochs to train. + + - num_valid_batches: Number of batches of validation data to use each + time we compute validation loss + + - symbols_per_batch: Number of symbols in each batch (sampler will + choose the number of sentences to satisfy this contraint). + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. 
+ + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + """ + params = AttributeDict( + { + "exp_dir": Path("conformer_lm/exp_1"), + "lm_dataset": Path("data/lm_training_5000/lm_data.pt"), + "num_tokens": 5000, + "blank_sym": 0, + "bos_sym": 1, + "eos_sym": 1, + "start_epoch": 0, + "num_epochs": 20, + "num_valid_batches": 100, + "symbols_per_batch": 10000, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 10, + "reset_interval": 200, + "valid_interval": 3000, + "beam_size": 10, + "accum_grad": 1, + "attention_dim": 512, + "nhead": 8, + "num_decoder_layers": 6, + "lr_factor": 2.0, + "warm_step": 20000, + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + model: nn.Module, + batch: Tuple, + is_training: bool, +): + + """ + Compute training or validation loss given the model and its inputs + (this corresponds to log-prob of the targets, with weighting + of 1.0 for masked subsequences + (including padding blanks), and something smaller, e.g. 
0.25, + for non-masked positions (this is not totally trivial due to + a small amount of randomization of symbols). + + This loss is not normalized; you can divide by batch[4].sum() + to get a normalized loss (i.e. divide by soft-count). + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of MaskedLmConformer in our case. + batch: + A batch of data, actually a tuple of 5 tensors (on the device), as returned + by collate_fn in ./dataset.py. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + + Returns: + Returns the loss as a scalar tensor. + """ + (masked_src_symbols, src_symbols, + tgt_symbols, src_key_padding_mask, tgt_weights) = batch + + with torch.set_grad_enabled(is_training): + memory, pos_emb = model(masked_src_symbols, src_key_padding_mask) + tgt_nll = model.decoder_nll(memory, pos_emb, src_symbols, + tgt_symbols, src_key_padding_mask) + loss = (tgt_nll * tgt_weights).sum() + + assert loss.requires_grad == is_training + + return loss + + +def compute_validation_loss( + device: torch.device, + params: AttributeDict, + model: nn.Module, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> None: + """Run the validation process. The validation loss + is saved in `params.valid_loss`. + """ + model.eval() + + tot_loss = 0.0 + tot_frames = 0.0 + for batch_idx, batch in enumerate(valid_dl): + batch = tuple(x.to(device) for x in batch) + + # `batch` is actually a tuple.. we'll unpack it later. + loss = compute_loss(model, batch, is_training=False) + num_frames = batch[4].sum() + + assert loss.requires_grad is False + assert ctc_loss.requires_grad is False + assert att_loss.requires_grad is False + + loss_cpu = loss.detach().cpu().item() + num_frames_cpu = num_frames.cpu().item() + + tot_loss += loss_cpu + tot_frames += num_frames_cpu + + + if world_size > 1: + s = torch.tensor( + [tot_loss, tot_frames], + device=loss.device, + ) + dist.all_reduce(s, op=dist.ReduceOp.SUM) + (tot_loss, tot_frames) = s.cpu().tolist() + + params.valid_loss = tot_loss / tot_frames + + if params.valid_loss < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = params.valid_loss + + +def train_one_epoch( + device: torch.device, + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + device: + The device to use for training (model must be on this device) + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. 
+ """ + model.train() # training mode + + tot_loss = 0.0 # sum of losses over all batches + tot_frames = 0.0 # sum of frames over all batches + + params.tot_loss = 0.0 + params.tot_frames = 0.0 + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch = tuple(x.to(device) for x in batch) + + loss = compute_loss( + model=model, + batch=batch, + is_training=True, + ) + + optimizer.zero_grad() + loss.backward() # We are not normalizing by the num-frames, but Adam/Madam are insensitive to the total + # gradient scale so this should not matter. + # clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + loss_cpu = loss.detach().cpu().item() + num_frames_cpu = batch[4].sum().cpu().item() + + tot_loss += loss_cpu + tot_frames += num_frames_cpu + + params.tot_frames += num_frames_cpu + params.tot_loss += loss_cpu + + tot_avg_loss = tot_loss / tot_frames + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"batch avg loss {loss_cpu/num_frames_cpu:.4f}, " + f"total avg loss: {tot_avg_loss:.4f}, " + f"batch size: {batch_size}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/current_loss", + loss_cpu / params.train_frames, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/tot_avg_loss", + tot_avg_loss, + params.batch_idx_train, + ) + if batch_idx > 0 and batch_idx % params.reset_interval == 0: + tot_loss = 0.0 # sum of losses over all batches + tot_frames = 0.0 # sum of frames over all batches + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + compute_validation_loss( + device=device, + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info( + f"Epoch {params.cur_epoch}, " + f"valid loss {params.valid_loss:.4f}," + f" best valid loss: {params.best_valid_loss:.4f} " + f"best valid epoch: {params.best_valid_epoch}" + ) + if tb_writer is not None: + tb_writer.add_scalar( + "train/valid_loss", + params.valid_loss, + params.batch_idx_train, + ) + + params.train_loss = params.tot_loss / params.tot_frames + + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(42) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + num_tokens = params.num_tokens + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + + logging.info("About to create model") + model = MaskedLmConformer( + num_classes=params.num_tokens, + d_model=params.attention_dim, + nhead=params.nhead, + num_decoder_layers=params.num_decoder_layers, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + optimizer = Moam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + train,test = dataset.load_train_test_lm_dataset(params.lm_dataset) + + collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=params.bos_sym, + eos_sym=params.eos_sym, + blank_sym=params.blank_sym, + mask_proportion=0.15, + padding_proportion=0.15, + randomize_proportion=0.05, + inv_mask_length=0.25, + unmasked_weight=0.25)) + + train_sampler = dataset.LmBatchSampler(train, + symbols_per_batch=params.symbols_per_batch, + world_size=world_size, rank=rank) + test_sampler = dataset.LmBatchSampler(test, + symbols_per_batch=params.symbols_per_batch, + world_size=world_size, rank=rank) + + train_dl = torch.utils.data.DataLoader(train, + batch_sampler=train_sampler, + collate_fn=collate_fn) + valid_dl = torch.utils.data.DataLoader(test, + batch_sampler=test_sampler, + collate_fn=collate_fn) + + for epoch in range(params.start_epoch, params.num_epochs): + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + device=device, + params=params, + model=model, + optimizer=optimizer, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() From 7711fba8670a1815d9092ab3e40c7bc93277df2c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 23 Aug 2021 22:40:23 +0800 Subject: [PATCH 14/26] Fix bugs; first version that is running successfully. 
--- egs/librispeech/ASR/conformer_lm/train.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py index 85d671ecc6..72bc654ea7 100755 --- a/egs/librispeech/ASR/conformer_lm/train.py +++ b/egs/librispeech/ASR/conformer_lm/train.py @@ -141,7 +141,7 @@ def get_params() -> AttributeDict: "start_epoch": 0, "num_epochs": 20, "num_valid_batches": 100, - "symbols_per_batch": 10000, + "symbols_per_batch": 5000, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, @@ -417,13 +417,13 @@ def train_one_epoch( f"Epoch {params.cur_epoch}, batch {batch_idx}, " f"batch avg loss {loss_cpu/num_frames_cpu:.4f}, " f"total avg loss: {tot_avg_loss:.4f}, " - f"batch size: {batch_size}" - ) + f"batch shape: {tuple(batch[0].shape)}") + if tb_writer is not None: tb_writer.add_scalar( "train/current_loss", - loss_cpu / params.train_frames, + loss_cpu / num_frames_cpu, params.batch_idx_train, ) tb_writer.add_scalar( @@ -549,7 +549,7 @@ def run(rank, world_size, args): collate_fn=collate_fn) for epoch in range(params.start_epoch, params.num_epochs): - train_dl.sampler.set_epoch(epoch) + train_sampler.set_epoch(epoch) cur_lr = optimizer._rate if tb_writer is not None: From 9576d6574fdbbdd89b1b89da16ac9edd6bbdefce Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 23 Aug 2021 23:45:03 +0800 Subject: [PATCH 15/26] Various bug fixes --- egs/librispeech/ASR/conformer_lm/train.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py index 72bc654ea7..66602ea1d3 100755 --- a/egs/librispeech/ASR/conformer_lm/train.py +++ b/egs/librispeech/ASR/conformer_lm/train.py @@ -140,7 +140,7 @@ def get_params() -> AttributeDict: "eos_sym": 1, "start_epoch": 0, "num_epochs": 20, - "num_valid_batches": 100, + "num_valid_batches": 200, "symbols_per_batch": 5000, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), @@ -288,8 +288,9 @@ def compute_loss( with torch.set_grad_enabled(is_training): memory, pos_emb = model(masked_src_symbols, src_key_padding_mask) - tgt_nll = model.decoder_nll(memory, pos_emb, src_symbols, - tgt_symbols, src_key_padding_mask) + decoder_nll_func = model.module.decoder_nll if isinstance(model, DDP) else model.decoder_nll + tgt_nll = decoder_nll_func(memory, pos_emb, src_symbols, + tgt_symbols, src_key_padding_mask) loss = (tgt_nll * tgt_weights).sum() assert loss.requires_grad == is_training @@ -312,6 +313,8 @@ def compute_validation_loss( tot_loss = 0.0 tot_frames = 0.0 for batch_idx, batch in enumerate(valid_dl): + if batch_idx == params.num_valid_batches: + break batch = tuple(x.to(device) for x in batch) # `batch` is actually a tuple.. we'll unpack it later. 
@@ -319,8 +322,6 @@ def compute_validation_loss( num_frames = batch[4].sum() assert loss.requires_grad is False - assert ctc_loss.requires_grad is False - assert att_loss.requires_grad is False loss_cpu = loss.detach().cpu().item() num_frames_cpu = num_frames.cpu().item() From e6eefeba882763c6c4e2cdab63a3215a1671cd82 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 24 Aug 2021 14:50:49 +0800 Subject: [PATCH 16/26] Changes to dataset to prevent OOM on batches with short sentences --- egs/librispeech/ASR/conformer_lm/dataset.py | 46 +++++++++++++++++---- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/dataset.py b/egs/librispeech/ASR/conformer_lm/dataset.py index 3074d10999..fcf7f39f01 100644 --- a/egs/librispeech/ASR/conformer_lm/dataset.py +++ b/egs/librispeech/ASR/conformer_lm/dataset.py @@ -2,6 +2,7 @@ import torch.distributed as dist import k2 import _k2 +import logging import sentencepiece as spm from pathlib import Path from typing import Optional, List, Tuple, Union @@ -333,6 +334,7 @@ def collate_fn(sentences: List[List[int]], """ assert blank_sym not in [bos_sym, eos_sym] max_sent_len = max([ len(s) for s in sentences]) + #logging.info(f"Sentence lengths: {[ len(s) for s in sentences]}") typical_mask_and_pad = int(max_sent_len * (mask_proportion + padding_proportion)) @@ -635,17 +637,22 @@ class LmBatchSampler(torch.utils.data.Sampler): """ def __init__(self, dataset: LmDataset, symbols_per_batch: int, - quadratic_constant: float = 0.005, + length_ceil: float = 200.0, + length_floor: float = 4.0, world_size: Optional[int] = None, rank: int = None, - seed: int = 0): + seed: int = 0, + delay_init: bool = False): """ Constructor documentation: dataset: the LmDataset object that we are sampling from. This class does not retain a reference to the LmDataset. symbols_per_batch: The number of BPE symbols desired in each minibatch - quadratic_constant: After the sentence length gets more than about - 1.0/quadratic_constant, the batch size will start decreasing + length_floor: When the sentence length gets less than about this much, + the batch size stops increasing inversely with sentence + length. Prevent OOM on batches with short sentences. + length_ceil: After the sentence length gets more than about + this much, the batch size will start decreasing as 1/(sentence-length^2). This is a mechanism to avoid excessive memory consumption in transformers, when sentence length gets long. @@ -654,10 +661,17 @@ class does not retain a reference to the LmDataset. rank: The rank of this sampler/process for distributed operation; if None, will be worked out from torch.distributed. seed: The random seed + delay_init: If true, will omit calling self.set_epoch(0) at the + end of the __init__ function. In this case the caller + must call set_epoch(0). [Setting this option is necessary + to work with data-loader worker processes plus DDP, since + set_epoch() will use ddp, which I believe is a no-no prior + to initializing data-loaders.] """ self.seed = seed self.symbols_per_batch = symbols_per_batch - self.quadratic_constant = quadratic_constant + self.length_floor = length_floor + self.quadratic_constant = 1.0 / length_ceil self._maybe_init_distributed(world_size=world_size, rank=rank) # a configuration constant we don't expose. @@ -698,8 +712,20 @@ class does not retain a reference to the LmDataset. # `data_indexes` above (this is not stored, as we know the formula). 
self.sentence_lengths = sentence_lengths - self.set_epoch(0) # this is responsible for setting self.sorted_data_indexes - + if not delay_init: + self.set_epoch(0) # this is responsible for setting self.sorted_data_indexes + + def _sync_sizes(self, device: torch.device = torch.device('cuda')): + # Calling this on all copies of a DDP setup will sync the sizes so that + # all copies have the exact same number of batches. I think + # this needs to be called with the GPU device, not sure if it would + # work otherwise. + if self.world_size > 1: + min_size = torch.tensor([len(self.batch_indices)], device=device, dtype=torch.int64) + dist.all_reduce(min_size, op=dist.ReduceOp.MIN) + min_size = min_size.to('cpu').item() + logging.info(f"world_size={self.world_size}, rank={self.rank}: reducing batch indices from {len(self.batch_indices)} to {min_size}") + self.batch_indices = self.batch_indices[0:min_size] def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]): if world_size is not None: @@ -714,6 +740,7 @@ def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int] self.rank = dist.get_rank() if rank is None else rank assert self.rank < self.world_size + def set_epoch(self, epoch: int): """ Must be called at the beginning of each epoch, before initializing the DataLoader, @@ -727,7 +754,7 @@ def set_epoch(self, epoch: int): # This mechanism regulates the batch size so that we don't get OOM in transformers # when the sentences are long. - sentence_lengths = sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant + sentence_lengths = (sentence_lengths + (sentence_lengths ** 2) * self.quadratic_constant) + self.length_floor values, indices = torch.sort(sentence_lengths) # values,indices dtypes: torch.float,torch.int64 @@ -741,7 +768,7 @@ def set_epoch(self, epoch: int): # now `batch_ids` will be: [0, 0, 0, 0, .., 0, 1, 1, 1, ... 1, 2, ... ], # saying which batch each element of values/indices belongs to. - batch_ids = (torch.cumsum(values, dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32) + batch_ids = (torch.cumsum(values.to(dtype=torch.double), dim=0) * (1.0 / self.symbols_per_batch)).to(dtype=torch.int32) batch_boundaries = torch.nonzero(batch_ids[1:] - batch_ids[:-1], as_tuple=True)[0] batch_boundaries.add_(1) @@ -754,6 +781,7 @@ def set_epoch(self, epoch: int): # necessary to randomize the order of these, to avoid returning batches # from shortest to longest sentences. self.batch_indices = torch.randperm(num_batches, generator=g, dtype=torch.int32).tolist() + self._sync_sizes() def __len__(self): From 0d97e689beaa584f4137bcbed7b4455304241d80 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 24 Aug 2021 21:59:41 +0800 Subject: [PATCH 17/26] Version I am running... 
--- egs/librispeech/ASR/conformer_lm/conformer.py | 2 +- egs/librispeech/ASR/conformer_lm/train.py | 26 ++++++++++++------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py index 1963056cc9..fe0a5eec98 100644 --- a/egs/librispeech/ASR/conformer_lm/conformer.py +++ b/egs/librispeech/ASR/conformer_lm/conformer.py @@ -21,7 +21,7 @@ def __init__( d_model: int = 256, nhead: int = 4, dim_feedforward: int = 2048, - num_encoder_layers: int = 12, + num_encoder_layers: int = 6, num_decoder_layers: int = 6, dropout: float = 0.1, cnn_module_kernel: int = 31, diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py index 66602ea1d3..e8a5c88882 100755 --- a/egs/librispeech/ASR/conformer_lm/train.py +++ b/egs/librispeech/ASR/conformer_lm/train.py @@ -317,7 +317,7 @@ def compute_validation_loss( break batch = tuple(x.to(device) for x in batch) - # `batch` is actually a tuple.. we'll unpack it later. + loss = compute_loss(model, batch, is_training=False) num_frames = batch[4].sum() @@ -390,17 +390,23 @@ def train_one_epoch( params.batch_idx_train += 1 batch = tuple(x.to(device) for x in batch) - loss = compute_loss( - model=model, + try: + loss = compute_loss( + model=model, batch=batch, - is_training=True, - ) + is_training=True, + ) + + optimizer.zero_grad() + loss.backward() + # We are not normalizing by the num-frames, but Adam/Madam are insensitive to the total + # gradient scale so this should not matter. + # clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + print(f"Error on batch of shape (N,T) = {batch[0].shape}") + raise e - optimizer.zero_grad() - loss.backward() # We are not normalizing by the num-frames, but Adam/Madam are insensitive to the total - # gradient scale so this should not matter. - # clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() loss_cpu = loss.detach().cpu().item() num_frames_cpu = batch[4].sum().cpu().item() From a7b61100de29f72ec40781879163e22c48702f66 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 25 Aug 2021 11:27:47 +0800 Subject: [PATCH 18/26] Use collate_fn as class. harmless but not necessary without multiple workers --- egs/librispeech/ASR/conformer_lm/dataset.py | 14 +++++++++++++- egs/librispeech/ASR/conformer_lm/train.py | 17 +++++++++-------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/dataset.py b/egs/librispeech/ASR/conformer_lm/dataset.py index fcf7f39f01..dd3ab8deb2 100644 --- a/egs/librispeech/ASR/conformer_lm/dataset.py +++ b/egs/librispeech/ASR/conformer_lm/dataset.py @@ -130,7 +130,12 @@ def mask_and_pad(sentence: List[int], # length of masked regions. num_split_points = int(torch.binomial(count=torch.tensor([float(sent_len - num_mask)]), prob=torch.tensor([mask_proportion * inv_mask_length / (1.0 - mask_proportion)])).item()) - assert num_split_points <= sent_len - num_mask + # Somehow this assertion failed, debugging it below. 
+ # assert num_split_points <= sent_len - num_mask + if num_split_points > sent_len - num_mask: + print(f"Warning about num_split_points: {num_split_points} > {sent_len} - {num_mask}") + num_split_points = sent_len - num_mask + assert isinstance(num_split_points, int) def split_into_subseqs(length: int , num_subseqs: int) -> List[int]: @@ -797,6 +802,13 @@ def __iter__(self): yield self.indices[batch_start:batch_end].tolist() +class CollateFn: + def __init__(self, **kwargs): + self.extra_args = kwargs + + def __call__(self, sentences: List[List[int]]): + return collate_fn(sentences, **self.extra_args) + diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py index e8a5c88882..0b7e49db5b 100755 --- a/egs/librispeech/ASR/conformer_lm/train.py +++ b/egs/librispeech/ASR/conformer_lm/train.py @@ -532,14 +532,14 @@ def run(rank, world_size, args): train,test = dataset.load_train_test_lm_dataset(params.lm_dataset) - collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=params.bos_sym, - eos_sym=params.eos_sym, - blank_sym=params.blank_sym, - mask_proportion=0.15, - padding_proportion=0.15, - randomize_proportion=0.05, - inv_mask_length=0.25, - unmasked_weight=0.25)) + collate_fn=dataset.CollateFn(bos_sym=params.bos_sym, + eos_sym=params.eos_sym, + blank_sym=params.blank_sym, + mask_proportion=0.15, + padding_proportion=0.15, + randomize_proportion=0.05, + inv_mask_length=0.25, + unmasked_weight=0.25) train_sampler = dataset.LmBatchSampler(train, symbols_per_batch=params.symbols_per_batch, @@ -551,6 +551,7 @@ def run(rank, world_size, args): train_dl = torch.utils.data.DataLoader(train, batch_sampler=train_sampler, collate_fn=collate_fn) + valid_dl = torch.utils.data.DataLoader(test, batch_sampler=test_sampler, collate_fn=collate_fn) From d045831a4f8e73be5da0a197637f4540ab12a64c Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 25 Aug 2021 15:54:36 +0800 Subject: [PATCH 19/26] Get dataset to work for empty input sentences; test it --- egs/librispeech/ASR/conformer_lm/dataset.py | 7 +--- .../ASR/conformer_lm/test_dataset_empty.py | 39 +++++++++++++++++++ 2 files changed, 41 insertions(+), 5 deletions(-) create mode 100644 egs/librispeech/ASR/conformer_lm/test_dataset_empty.py diff --git a/egs/librispeech/ASR/conformer_lm/dataset.py b/egs/librispeech/ASR/conformer_lm/dataset.py index dd3ab8deb2..4f466a9e1b 100644 --- a/egs/librispeech/ASR/conformer_lm/dataset.py +++ b/egs/librispeech/ASR/conformer_lm/dataset.py @@ -116,7 +116,7 @@ def mask_and_pad(sentence: List[int], num_pad -= max(0, sent_len + 2 + num_pad - seq_len) if num_mask + num_pad == 0: - num_mask += 1 + num_pad += 1 # num_split_points is the number of times we split the (masked+padded) # region, so the total number of (masking+padding) subsequences will be @@ -131,10 +131,7 @@ def mask_and_pad(sentence: List[int], num_split_points = int(torch.binomial(count=torch.tensor([float(sent_len - num_mask)]), prob=torch.tensor([mask_proportion * inv_mask_length / (1.0 - mask_proportion)])).item()) # Somehow this assertion failed, debugging it below. 
- # assert num_split_points <= sent_len - num_mask - if num_split_points > sent_len - num_mask: - print(f"Warning about num_split_points: {num_split_points} > {sent_len} - {num_mask}") - num_split_points = sent_len - num_mask + assert num_split_points <= sent_len - num_mask assert isinstance(num_split_points, int) diff --git a/egs/librispeech/ASR/conformer_lm/test_dataset_empty.py b/egs/librispeech/ASR/conformer_lm/test_dataset_empty.py new file mode 100644 index 0000000000..7e933f07b4 --- /dev/null +++ b/egs/librispeech/ASR/conformer_lm/test_dataset_empty.py @@ -0,0 +1,39 @@ +import k2 +import torch +import _k2 +import dataset +from dataset import LmDataset +import os +from torch import multiprocessing as mp +import torch.distributed as dist + +def local_collate_fn(sentences): + return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=False) + +x = _k2.RaggedInt('[[1]]') # make sure library initialized? + +if __name__ == '__main__': + + mp.set_start_method('spawn') + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12344" + + dist.init_process_group(backend="nccl", group_name="main", + rank=0, world_size=1) + + words = k2.RaggedInt('[[0][1 2]]') + sentences = k2.RaggedInt('[[1][][][][][]]') + + train = LmDataset(sentences, words) + + + sampler = dataset.LmBatchSampler(train, symbols_per_batch=10, world_size=1, rank=0) + + a = iter(sampler) + print(str(next(a))) + + train_dl = torch.utils.data.DataLoader(train, batch_sampler=sampler, + collate_fn=local_collate_fn, + num_workers=0) + x = iter(train_dl) + print(str(next(x))) From ccf7bdec230edb9a833770368b92b74503ae1125 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Sat, 28 Aug 2021 21:51:54 +0800 Subject: [PATCH 20/26] Add Foam optimizer; I used this from epoch 3. --- egs/librispeech/ASR/conformer_lm/madam.py | 186 +++++++++++++++++- .../ASR/conformer_lm/test_dataset.py | 39 +++- egs/librispeech/ASR/conformer_lm/train.py | 13 +- 3 files changed, 217 insertions(+), 21 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/madam.py b/egs/librispeech/ASR/conformer_lm/madam.py index aa605c30b6..07266a63b4 100644 --- a/egs/librispeech/ASR/conformer_lm/madam.py +++ b/egs/librispeech/ASR/conformer_lm/madam.py @@ -811,6 +811,140 @@ def load_state_dict(self, state_dict): setattr(self, key, value) +class Foam(object): + """ + Implements Foam optimizer. This is a modified version of the Noam optimizer + which was proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf, + but changed to use Madam (see above) instead of Adam as the base optimizer, and then + to change the learning rate schedule and how it is specified. + + + This code was modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py + + + Args: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + + warm_step: number of warmup steps before the learning rate starts to decrease + (it increases until this point). + max_lrate: The learning rate at its maximum, on step `warm_step` + knee_factor: The multiple of `max_lrate` after which the learning rate will + start to decrease more like 1/x. It increases linearly from 0 to + `warm_step`, then decreases approximately as 1/sqrt(x) from + `warm_step` to `warm_step * knee_factor`, then decreases + approximately as 1/x from `warm_step * knee_factor` onwards. 
+ + min_target_rms: this is a parameter of the Madam optimizer; it represents a floor + on the "target root-mean-square value" that is used when the initialization + of a tensor is zero or below this value. It may be worth optimizing. + Don't worry about tensors with fewer than 2 dimensions when setting this, + these are not subject to our l2 formula. + limit_grad_factor: Another parameter of Madam, you can set this to a finite + value, e.g. 2.0, to activate a mechanism that limits the norms of + larger-than-usual gradients. This seems to cause a slowdown, likely due + to GPU->CPU transfers, and it is disabled by setting it to infinity. + l2_period: mechanism to improve the optimization speed, by only applying the l2 + regularization (which is a complicated formula) every this-many + minibatches. E.g. can set it to 2 or 4. + """ + + def __init__(self, + params, + max_lrate: float = 5.0e-04, + warm_step: int = 25000, + knee_factor: float = 8.0, + min_target_rms: float = 0.05, + limit_grad_factor: float = float('inf'), + l2_period: int = 1) -> None: + """Construct an Noam object.""" + self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9, + min_target_rms=min_target_rms, + limit_grad_factor=limit_grad_factor, + l2_period=l2_period) + self._step = 0 + + self._max_lrate = max_lrate + self._warm_step = warm_step + self._knee_factor = knee_factor + self._rate = 0 + + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + + def rate(self, step=None): + """ + Suppose the step of optimization is 's', i.e. with s = 0, 1, 2... + We define 't = s / warm_step', i.e. t is the step s, normalized so that it + is 1.0 at warm_step. Our formula for the learning rate as a function of + t is: + rate = max_lrate * (t <= 1.0 ? t : + sqrt((2 + alpha) / (1 + t + alpha t^2))) + where alpha is chosen so that the 't' and 'alpha t^2' terms are identical + at t == knee_factor (this means alpha = 1.0/knee_factor). So the + learning rate increases linearly from t=00 to t=1, and decreases + after that. You can see + that sqrt((2 + alpha) / (1 + t + alpha t^2))) is 1.0 when t == 1, + which is why the line and the curve meet at that point. + + On the denominator of that ratio, the "t" term makes it decrease a + bit like 1/sqrt(t) in 1 <= t <= warm_step; the "alpha t^2" term + makes it decrease a bit like 1/t for t > warm_step; and the "1" + term makes it decrease a bit slower than 1/sqrt(t) when t is quite + close to 1.0 (so we linger a little, near the maximum learning rate). + + This learning rate schedule ultimately decreases more aggressively + than Noam, i.e. as 1 / t instead of 1 / sqrt(t). The reason we + feel this will work better in conjunction with Madam, is that Madam + keeps the norms of the parameters approximately constant throughout + training; whereas with Noam, if there is no weight decay, these + norms tend to increase as training progresses (although rather + unevenly across different parameter tensors). + As the norms of the parameters increase, the relative changes + in parameters get smaller (the step sizes don't change because + Adam normalizes the gradient magnitudes; they'd get smaller otherwise). 
+ So Noam doesn't have to decrease the learning rate too aggressively + because even with a fixed learning rate, the effective learning rate + would be decreasing (again, this only applies without weight decay). + """ + if step is None: + step = self._step + t = step / self._warm_step # floating point division.. t is the normalized step. + alpha = 1.0 / self._knee_factor + return self._max_lrate * (t if t <= 1.0 else + ((2 + alpha) / (1 + t + alpha * t * t)) ** 0.5) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + } + + def load_state_dict(self, state_dict): + """Load state_dict. This is compatible with reading a Moam state_dict""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + elif key == '_step': + self._step = value + + class TestModel(torch.nn.Module): """Class for testing the Madam optimizer""" @@ -844,9 +978,9 @@ def test_madam(): inf_grad_max_count = 200 if torch.cuda.is_available(): devices_and_l2 = [(torch.device('cuda'), True), - (torch.device('cuda'), False)] - #(torch.device('cpu'), True), - #(torch.device('cpu'), False)] + (torch.device('cuda'), False), + (torch.device('cpu'), True), + (torch.device('cpu'), False)] else: devices_and_l2 = [(torch.device('cpu'), True), (torch.device('cpu'), False)] @@ -922,6 +1056,48 @@ def get_elems_rms(x: Tensor) -> Tensor: print("") +def test_foam(): + print("Testing Foam optimizer") + model = TestModel() + # min_target_rms=0.01 is for testing, so the target equals the initial RMS + # and we can more easily tell whether our update has the desired effect. + optimizer = Foam(model.parameters(), + max_lrate=1.0e-03, warm_step=300, + min_target_rms=0.01, + limit_grad_factor=4.0) + + + def get_elems_rms(x: Tensor) -> Tensor: + return ((x ** 2).sum() / x.numel()).sqrt().item() + + for i in range(1000): + if i % 100 == 0: + rms_values = (get_elems_rms(model.first_layers[0].weight), + get_elems_rms(model.first_layers[2].weight), + get_elems_rms(model.conv1.weight), + get_elems_rms(model.conv2.weight)) + print(f"Iter {i} (Foam): stddevs = {rms_values} ") + B = 4 + T = 20 + x = torch.randn(B, T, 100) + y = model(x) + yderiv = torch.randn_like(y) + if i % 190 <= 3 and i > 0: + yderiv *= 100.0 + if i % 550 == 0 and i > 0: + yderiv *= float('inf') + + y.backward(gradient=yderiv) + optimizer.step() + model.zero_grad() + print("") + + state_dict = optimizer.state_dict() + step = optimizer._step + optimizer._step = 0 + optimizer.load_state_dict(state_dict) + assert optimizer._step == step + def test_to_device(): if not torch.cuda.is_available(): @@ -951,8 +1127,10 @@ def main(): #test_to_device() random.seed(0) torch.random.manual_seed(0) + test_foam() + test_moam() test_madam() - #test_moam() + if __name__ == '__main__': diff --git a/egs/librispeech/ASR/conformer_lm/test_dataset.py b/egs/librispeech/ASR/conformer_lm/test_dataset.py index ed38ed11a2..b82da7899d 100644 --- a/egs/librispeech/ASR/conformer_lm/test_dataset.py +++ b/egs/librispeech/ASR/conformer_lm/test_dataset.py @@ -1,13 +1,34 @@ -import dataset +import k2 import torch +import _k2 +import dataset +import os +from torch import multiprocessing as mp +import torch.distributed as dist + +def local_collate_fn(sentences): + return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=True) + +x = _k2.RaggedInt('[[1]]') # make sure library initialized? 
+ +if __name__ == '__main__': + + #mp.set_start_method('spawn') + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12344" + + dist.init_process_group(backend="nccl", group_name="main", + rank=0, world_size=1) + train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt') + sampler = dataset.LmBatchSampler(test, symbols_per_batch=5000, world_size=2, rank=0) + print("len(sampler) = ", len(sampler)) -train,test = dataset.load_train_test_lm_dataset('../data/lm_training_5000/lm_data.pt') -sampler = dataset.LmBatchSampler(test, symbols_per_batch=1000, world_size=2, rank=0) -a = iter(sampler) -print(str(next(a))) + a = iter(sampler) + print(str(next(a))) -collate_fn=(lambda x:dataset.collate_fn(x, bos_sym=1, eos_sym=1, blank_sym=0, debug=True)) -train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler, collate_fn=collate_fn) -x = iter(train_dl) -print(str(next(x))) + train_dl = torch.utils.data.DataLoader(test, batch_sampler=sampler, + collate_fn=local_collate_fn, + num_workers=2) + x = iter(train_dl) + print(str(next(x))) diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py index 0b7e49db5b..5ca267147e 100755 --- a/egs/librispeech/ASR/conformer_lm/train.py +++ b/egs/librispeech/ASR/conformer_lm/train.py @@ -35,7 +35,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter -from madam import Moam +from madam import Foam from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl @@ -138,7 +138,7 @@ def get_params() -> AttributeDict: "blank_sym": 0, "bos_sym": 1, "eos_sym": 1, - "start_epoch": 0, + "start_epoch": 3, "num_epochs": 20, "num_valid_batches": 200, "symbols_per_batch": 5000, @@ -155,8 +155,7 @@ def get_params() -> AttributeDict: "attention_dim": 512, "nhead": 8, "num_decoder_layers": 6, - "lr_factor": 2.0, - "warm_step": 20000, + "max_lrate": 5.0e-04 } ) @@ -520,11 +519,9 @@ def run(rank, world_size, args): if world_size > 1: model = DDP(model, device_ids=[rank]) - optimizer = Moam( + optimizer = Foam( model.parameters(), - model_size=params.attention_dim, - factor=params.lr_factor, - warm_step=params.warm_step, + max_lrate=params.max_lrate ) if checkpoints: From 573e0582d8319c5b23044fe3c5f6ebef8c7f8557 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 30 Aug 2021 14:10:21 +0800 Subject: [PATCH 21/26] Run in exp_2, with foam from start, knee_factor=5.0, initial_lrate=2e-04. 
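For reference, this is roughly what the Foam schedule from the previous commit
gives with the settings mentioned here (knee_factor=5.0, max_lrate=2e-04, and
warm_step left at its default of 25000); the function below just mirrors
Foam.rate() and the printed numbers are purely illustrative:

def foam_rate(step, max_lrate=2.0e-04, warm_step=25000, knee_factor=5.0):
    # linear warm-up to max_lrate, then roughly 1/sqrt(t) until about
    # knee_factor * warm_step steps, and roughly 1/t after that
    t = step / warm_step
    alpha = 1.0 / knee_factor
    return max_lrate * (t if t <= 1.0 else ((2 + alpha) / (1 + t + alpha * t * t)) ** 0.5)

for step in (5000, 25000, 50000, 125000, 250000):
    print(step, foam_rate(step))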
--- egs/librispeech/ASR/conformer_lm/madam.py | 2 +- egs/librispeech/ASR/conformer_lm/train.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/madam.py b/egs/librispeech/ASR/conformer_lm/madam.py index 07266a63b4..36716efecc 100644 --- a/egs/librispeech/ASR/conformer_lm/madam.py +++ b/egs/librispeech/ASR/conformer_lm/madam.py @@ -852,7 +852,7 @@ def __init__(self, params, max_lrate: float = 5.0e-04, warm_step: int = 25000, - knee_factor: float = 8.0, + knee_factor: float = 5.0, min_target_rms: float = 0.05, limit_grad_factor: float = float('inf'), l2_period: int = 1) -> None: diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py index 5ca267147e..4c0219eb18 100755 --- a/egs/librispeech/ASR/conformer_lm/train.py +++ b/egs/librispeech/ASR/conformer_lm/train.py @@ -132,13 +132,13 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { - "exp_dir": Path("conformer_lm/exp_1"), + "exp_dir": Path("conformer_lm/exp_2"), "lm_dataset": Path("data/lm_training_5000/lm_data.pt"), "num_tokens": 5000, "blank_sym": 0, "bos_sym": 1, "eos_sym": 1, - "start_epoch": 3, + "start_epoch": 0, "num_epochs": 20, "num_valid_batches": 200, "symbols_per_batch": 5000, @@ -155,7 +155,7 @@ def get_params() -> AttributeDict: "attention_dim": 512, "nhead": 8, "num_decoder_layers": 6, - "max_lrate": 5.0e-04 + "max_lrate": 2.0e-04 # was 5.0e-04, then from start_epoch=9 used max_lrate=2.0e-04, then from start_epoch=11 used 1.0e-04. } ) From d313c27c1474a96b44cca175e889b59df2cb7a06 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Tue, 7 Sep 2021 20:58:00 +0800 Subject: [PATCH 22/26] Change configuration again.. not great performance. --- egs/librispeech/ASR/conformer_lm/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py index 4c0219eb18..d70d066741 100755 --- a/egs/librispeech/ASR/conformer_lm/train.py +++ b/egs/librispeech/ASR/conformer_lm/train.py @@ -132,7 +132,8 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { - "exp_dir": Path("conformer_lm/exp_2"), + # exp_3, vs. exp_2, is using 5e-04 not 2d-04 as max learning rate. + "exp_dir": Path("conformer_lm/exp_3"), "lm_dataset": Path("data/lm_training_5000/lm_data.pt"), "num_tokens": 5000, "blank_sym": 0, @@ -155,7 +156,7 @@ def get_params() -> AttributeDict: "attention_dim": 512, "nhead": 8, "num_decoder_layers": 6, - "max_lrate": 2.0e-04 # was 5.0e-04, then from start_epoch=9 used max_lrate=2.0e-04, then from start_epoch=11 used 1.0e-04. + "max_lrate": 5.0e-04 } ) From 56a88badd11271d2a7b98fe22d3ad0421e94dd84 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 8 Sep 2021 13:59:50 +0800 Subject: [PATCH 23/26] Move to Gloam optimizer, exponential lrate --- egs/librispeech/ASR/conformer_lm/madam.py | 152 +++++++++++++++++++++- egs/librispeech/ASR/conformer_lm/train.py | 15 ++- 2 files changed, 158 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/madam.py b/egs/librispeech/ASR/conformer_lm/madam.py index 36716efecc..bc81683305 100644 --- a/egs/librispeech/ASR/conformer_lm/madam.py +++ b/egs/librispeech/ASR/conformer_lm/madam.py @@ -813,11 +813,6 @@ def load_state_dict(self, state_dict): class Foam(object): """ - Implements Foam optimizer. 
This is a modified version of the Noam optimizer - which was proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf, - but changed to use Madam (see above) instead of Adam as the base optimizer, and then - to change the learning rate schedule and how it is specified. - This code was modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py @@ -946,6 +941,153 @@ def load_state_dict(self, state_dict): +class Gloam(object): + """ + Implements Gloam optimizer. This is a modified version of the Noam optimizer + which was proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf, + but changed to use Madam (see above) instead of Adam as the base optimizer, and then + to change the learning rate schedule and how it is specified. We have + a warm-up stage, but after it gets to `max_lrate` it stays constant for the + rest of the 1st epoch, and after that, only changes on epoch boundaries. + + CAUTION: you have to call set_epoch() every epoch, to set the epoch. If you don't do this, + this won't work! + + + This code was modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py + + + Args: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + + warm_step: number of warmup steps before the learning rate starts to decrease + (it increases until this point). + max_lrate: The learning rate at its maximum, on step `warm_step` + first_decrease_epoch: The epoch number on which to start decreasing the + learning rate. + decay_per_epoch: + min_target_rms: this is a parameter of the Madam optimizer; it represents a floor + on the "target root-mean-square value" that is used when the initialization + of a tensor is zero or below this value. It may be worth optimizing. + Don't worry about tensors with fewer than 2 dimensions when setting this, + these are not subject to our l2 formula. + limit_grad_factor: Another parameter of Madam, you can set this to a finite + value, e.g. 2.0, to activate a mechanism that limits the norms of + larger-than-usual gradients. This seems to cause a slowdown, likely due + to GPU->CPU transfers, and it is disabled by setting it to infinity. + l2_period: mechanism to improve the optimization speed, by only applying the l2 + regularization (which is a complicated formula) every this-many + minibatches. E.g. can set it to 2 or 4. 
+ """ + + def __init__(self, + params, + max_lrate: float = 5.0e-04, + warm_step: int = 25000, + first_decrease_epoch: int = 1, + decay_per_epoch: float = 0.85, + min_target_rms: float = 0.05, + limit_grad_factor: float = float('inf'), + l2_period: int = 1) -> None: + """Construct an Noam object.""" + self.optimizer = Madam(params, lr=0, betas=(0.9, 0.98), eps=1e-9, + min_target_rms=min_target_rms, + limit_grad_factor=limit_grad_factor, + l2_period=l2_period) + self._step = 0 + + self._max_lrate = max_lrate + self._warm_step = warm_step + self._first_decrease_epoch = first_decrease_epoch + self._decay_per_epoch = decay_per_epoch + self._rate = 0 + self._epoch = 0 + + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def set_epoch(self, epoch: int): + self._epoch = epoch + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + + def rate(self, step=None): + """ + Suppose the step of optimization is 's', i.e. with s = 0, 1, 2... + We define 't = s / warm_step', i.e. t is the step s, normalized so that it + is 1.0 at warm_step. Our formula for the learning rate as a function of + t is: + rate = max_lrate * (t <= 1.0 ? t : + sqrt((2 + alpha) / (1 + t + alpha t^2))) + where alpha is chosen so that the 't' and 'alpha t^2' terms are identical + at t == knee_factor (this means alpha = 1.0/knee_factor). So the + learning rate increases linearly from t=00 to t=1, and decreases + after that. You can see + that sqrt((2 + alpha) / (1 + t + alpha t^2))) is 1.0 when t == 1, + which is why the line and the curve meet at that point. + + On the denominator of that ratio, the "t" term makes it decrease a + bit like 1/sqrt(t) in 1 <= t <= warm_step; the "alpha t^2" term + makes it decrease a bit like 1/t for t > warm_step; and the "1" + term makes it decrease a bit slower than 1/sqrt(t) when t is quite + close to 1.0 (so we linger a little, near the maximum learning rate). + + This learning rate schedule ultimately decreases more aggressively + than Noam, i.e. as 1 / t instead of 1 / sqrt(t). The reason we + feel this will work better in conjunction with Madam, is that Madam + keeps the norms of the parameters approximately constant throughout + training; whereas with Noam, if there is no weight decay, these + norms tend to increase as training progresses (although rather + unevenly across different parameter tensors). + As the norms of the parameters increase, the relative changes + in parameters get smaller (the step sizes don't change because + Adam normalizes the gradient magnitudes; they'd get smaller otherwise). + So Noam doesn't have to decrease the learning rate too aggressively + because even with a fixed learning rate, the effective learning rate + would be decreasing (again, this only applies without weight decay). + """ + if step is None: + step = self._step + t = step / self._warm_step # floating point division.. t is the normalized step. + base_rate = self._max_lrate * (t if t <= 1.0 else 1.0) + epoch_rate = self._decay_per_epoch ** max(0, self._epoch + 1 - self._first_decrease_epoch) + return base_rate * epoch_rate + + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "_epoch": self._epoch, + } + + def load_state_dict(self, state_dict): + """Load state_dict. 
This is compatible with reading a Moam state_dict""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + elif key == '_step': + self._step = value + elif key == '_epoch': + self._epoch = value + + + class TestModel(torch.nn.Module): """Class for testing the Madam optimizer""" def __init__(self): diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py index d70d066741..dd35a2d770 100755 --- a/egs/librispeech/ASR/conformer_lm/train.py +++ b/egs/librispeech/ASR/conformer_lm/train.py @@ -35,7 +35,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter -from madam import Foam +from madam import Gloam from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl @@ -133,7 +133,8 @@ def get_params() -> AttributeDict: params = AttributeDict( { # exp_3, vs. exp_2, is using 5e-04 not 2d-04 as max learning rate. - "exp_dir": Path("conformer_lm/exp_3"), + # exp_4, vs. exp_3, is using the Gloam optimizer with + "exp_dir": Path("conformer_lm/exp_4"), "lm_dataset": Path("data/lm_training_5000/lm_data.pt"), "num_tokens": 5000, "blank_sym": 0, @@ -520,9 +521,13 @@ def run(rank, world_size, args): if world_size > 1: model = DDP(model, device_ids=[rank]) - optimizer = Foam( + # Caution: don't forget to do optimizer.set_epoch() with Gloam! + # Don't remove this warning! + optimizer = Gloam( model.parameters(), - max_lrate=params.max_lrate + max_lrate=params.max_lrate, + first_decrease_epoch=2, + decay_per_epoch=0.85 ) if checkpoints: @@ -556,6 +561,8 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): train_sampler.set_epoch(epoch) + optimizer.set_epoch(epoch) # Caution: this is specific to the Gloam + # optimizer. cur_lr = optimizer._rate if tb_writer is not None: From d0e5b9b8a5361fae953d79591ed073a3e3be63b7 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Thu, 9 Sep 2021 14:08:19 +0800 Subject: [PATCH 24/26] Change to exp_5, 1/sqrt(t) component. --- egs/librispeech/ASR/conformer_lm/madam.py | 33 ++++------------------- egs/librispeech/ASR/conformer_lm/train.py | 6 +++-- 2 files changed, 9 insertions(+), 30 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/madam.py b/egs/librispeech/ASR/conformer_lm/madam.py index bc81683305..b7bb77ec9d 100644 --- a/egs/librispeech/ASR/conformer_lm/madam.py +++ b/egs/librispeech/ASR/conformer_lm/madam.py @@ -1028,39 +1028,16 @@ def rate(self, step=None): We define 't = s / warm_step', i.e. t is the step s, normalized so that it is 1.0 at warm_step. Our formula for the learning rate as a function of t is: - rate = max_lrate * (t <= 1.0 ? t : - sqrt((2 + alpha) / (1 + t + alpha t^2))) - where alpha is chosen so that the 't' and 'alpha t^2' terms are identical - at t == knee_factor (this means alpha = 1.0/knee_factor). So the - learning rate increases linearly from t=00 to t=1, and decreases - after that. You can see - that sqrt((2 + alpha) / (1 + t + alpha t^2))) is 1.0 when t == 1, - which is why the line and the curve meet at that point. 
- - On the denominator of that ratio, the "t" term makes it decrease a - bit like 1/sqrt(t) in 1 <= t <= warm_step; the "alpha t^2" term - makes it decrease a bit like 1/t for t > warm_step; and the "1" - term makes it decrease a bit slower than 1/sqrt(t) when t is quite - close to 1.0 (so we linger a little, near the maximum learning rate). + base_rate = max_lrate * (t <= 1.0 ? t : t ** -0.5) + epoch_rate = [starts at 1.0 but from first_decrease_epoch, start decreasing it + by a factor of decay_per_epoch] + rate = base_rate * epoch_rate - This learning rate schedule ultimately decreases more aggressively - than Noam, i.e. as 1 / t instead of 1 / sqrt(t). The reason we - feel this will work better in conjunction with Madam, is that Madam - keeps the norms of the parameters approximately constant throughout - training; whereas with Noam, if there is no weight decay, these - norms tend to increase as training progresses (although rather - unevenly across different parameter tensors). - As the norms of the parameters increase, the relative changes - in parameters get smaller (the step sizes don't change because - Adam normalizes the gradient magnitudes; they'd get smaller otherwise). - So Noam doesn't have to decrease the learning rate too aggressively - because even with a fixed learning rate, the effective learning rate - would be decreasing (again, this only applies without weight decay). """ if step is None: step = self._step t = step / self._warm_step # floating point division.. t is the normalized step. - base_rate = self._max_lrate * (t if t <= 1.0 else 1.0) + base_rate = self._max_lrate * (t if t <= 1.0 else t ** -0.5) epoch_rate = self._decay_per_epoch ** max(0, self._epoch + 1 - self._first_decrease_epoch) return base_rate * epoch_rate diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py index dd35a2d770..04c3f8ccde 100755 --- a/egs/librispeech/ASR/conformer_lm/train.py +++ b/egs/librispeech/ASR/conformer_lm/train.py @@ -134,7 +134,9 @@ def get_params() -> AttributeDict: { # exp_3, vs. exp_2, is using 5e-04 not 2d-04 as max learning rate. # exp_4, vs. exp_3, is using the Gloam optimizer with - "exp_dir": Path("conformer_lm/exp_4"), + # in exp_5, vs. exp_4, we change Gloam to have a 1/sqrt(t) factor + # as well as the exponential part. + "exp_dir": Path("conformer_lm/exp_5"), "lm_dataset": Path("data/lm_training_5000/lm_data.pt"), "num_tokens": 5000, "blank_sym": 0, @@ -526,7 +528,7 @@ def run(rank, world_size, args): optimizer = Gloam( model.parameters(), max_lrate=params.max_lrate, - first_decrease_epoch=2, + first_decrease_epoch=1, decay_per_epoch=0.85 ) From 3ce1de337dd7154967369bbb6ae61a7abef6586a Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 13 Sep 2021 20:57:02 +0800 Subject: [PATCH 25/26] UPdates for new k2 version; change LR decay from 0.85 to 0.9 --- egs/librispeech/ASR/conformer_lm/dataset.py | 46 ++++++++++--------- .../ASR/conformer_lm/test_dataset.py | 2 - egs/librispeech/ASR/conformer_lm/train.py | 7 +-- .../ASR/local/prepare_lm_training_data.py | 4 +- egs/librispeech/ASR/prepare.sh | 4 +- 5 files changed, 32 insertions(+), 31 deletions(-) diff --git a/egs/librispeech/ASR/conformer_lm/dataset.py b/egs/librispeech/ASR/conformer_lm/dataset.py index 4f466a9e1b..6c28c21cad 100644 --- a/egs/librispeech/ASR/conformer_lm/dataset.py +++ b/egs/librispeech/ASR/conformer_lm/dataset.py @@ -14,8 +14,8 @@ class LmDataset(torch.utils.data.Dataset): Torch dataset for language modeling data. This is a map-style dataset. 
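# (Recap, not part of the patch: the Gloam schedule as it stands after the
# last two commits, with the values they settle on; numbers are illustrative
# only.)  It warms up linearly to max_lrate, then decays with the step count
# as 1/sqrt(t), and is additionally scaled by decay_per_epoch for every epoch
# past first_decrease_epoch -- which only takes effect if train.py keeps
# calling optimizer.set_epoch(epoch).
def gloam_rate(step, epoch, max_lrate=5.0e-04, warm_step=25000,
               first_decrease_epoch=1, decay_per_epoch=0.9):
    t = step / warm_step
    base_rate = max_lrate * (t if t <= 1.0 else t ** -0.5)
    epoch_rate = decay_per_epoch ** max(0, epoch + 1 - first_decrease_epoch)
    return base_rate * epoch_rate

for epoch, step in [(0, 25000), (1, 100000), (3, 300000)]:
    print(epoch, step, gloam_rate(step, epoch))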
The indices are integers. """ - def __init__(self, sentences: k2.RaggedInt, - words: k2.RaggedInt): + def __init__(self, sentences: k2.RaggedTensor, + words: k2.RaggedTensor): super(LmDataset, self).__init__() self.sentences = sentences self.words = words @@ -30,12 +30,17 @@ def __getitem__(self, i: int): Return the i'th sentence, as a list of ints (representing BPE pieces, without bos or eos symbols). """ - # It would be nicer if we could just return self.sentences[i].tolist(), but - # for now that operator on k2.RaggedInt is not implemented. - row_splits = self.sentences.row_splits(1) + # in future will just do: + #return self.words[self.sentences[i]].tolist() + + # It would be nicer if we could just return self.sentences[i].tolist(), + # but for now that operator on k2.RaggedInt does not support when the + # ragged has only 2 axes. + row_splits = self.sentences.shape.row_splits(1) (begin, end) = row_splits[i:i+2].tolist() - sentence = self.sentences.values()[begin:end] - return k2.index(self.words, sentence).values().tolist() + sentence = self.sentences.data[begin:end] + sentence, _ = self.words.index(sentence, axis=0, need_value_indexes=False) + return sentence.data.tolist() def load_train_test_lm_dataset(archive_fn: Union[str,Path], @@ -45,22 +50,21 @@ def load_train_test_lm_dataset(archive_fn: Union[str,Path], """ d = torch.load(archive_fn) - words = d['words'] # a k2.RaggedInt with 2 axes, maps from word-ids to sequences of BPE pieces - sentences = d['data'] # a k2.RaggedInt + words = d['words'] # a k2.RaggedTensor with 2 axes, maps from word-ids to sequences of BPE pieces + sentences = d['data'] # a k2.RaggedTensor with torch.random.fork_rng(devices=[]): g = torch.manual_seed(0) num_sentences = sentences.tot_size(0) # probably the generator (g) argument to torch.randperm below is not necessary. sentence_perm = torch.randperm(num_sentences, generator=g, dtype=torch.int32) - sentences = k2.index(sentences, sentence_perm) + sentences, _ = sentences.index(sentence_perm, axis=0, need_value_indexes=False) num_test_sentences = int(num_sentences * test_proportion) axis=0 - train_sents = _k2.ragged_int_arange(sentences, axis, - num_test_sentences, num_sentences) - test_sents = _k2.ragged_int_arange(sentences, axis, 0, num_test_sentences) + train_sents = sentences.arange(axis, num_test_sentences, num_sentences) + test_sents = sentences.arange(axis, 0, num_test_sentences) return LmDataset(train_sents, words), LmDataset(test_sents, words) @@ -683,27 +687,25 @@ class does not retain a reference to the LmDataset. # sampler is reponsible for (all of them, in the non-distributed case). data_indexes = torch.arange(self.rank, len(dataset), self.world_size, dtype=torch.int32) # dtype=torch.int32 - word_row_splits = dataset.words.row_splits(1) # dtype=torch.int32 + word_row_splits = dataset.words.shape.row_splits(1) # dtype=torch.int32 word_lengths = word_row_splits[1:] - word_row_splits[:-1] # dtype=torch.int32 # the sentences this sampler is responsible for, as sequences of words. # It's a ragged tensor of int32 - sentences = k2.index(dataset.sentences, data_indexes) + sentences, _ = dataset.sentences.index(data_indexes, axis=0) - # sentence_lengths is a k2.RaggedInt like `sentences`, but with the words replaced + # sentence_lengths is a k2.RaggedTensor like `sentences`, but with the words replaced # with their respective lengths, in BPE pieces. 
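# (Aside, toy numbers only; not part of the patch.)  A plain-PyTorch picture
# of what the ragged indexing and per-sublist sum just below compute: look up
# the BPE length of every word occurrence, then total the lengths sentence by
# sentence.
import torch

word_lengths = torch.tensor([3., 1., 2., 4.])    # BPE pieces per word-id (made up)
word_ids     = torch.tensor([0, 2, 1, 3, 3])     # all sentences' words, concatenated
row_splits   = torch.tensor([0, 2, 5])           # sentence boundaries into word_ids
per_word = word_lengths[word_ids]                # length of each word occurrence
sizes = (row_splits[1:] - row_splits[:-1]).tolist()
sentence_lengths = torch.stack([s.sum() for s in per_word.split(sizes)])
print(sentence_lengths)                          # tensor([5., 9.])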
- sentence_lengths = k2.index(word_lengths, sentences) + sentence_lengths = k2.ragged.index(word_lengths, sentences) del sentences # save memory - assert isinstance(sentence_lengths, k2.RaggedInt) + assert isinstance(sentence_lengths, k2.RaggedTensor) # convert to float so sum_per_sublist() will work (TODO: sum_per_sublist() will eventually # support int32.) - sentence_lengths = k2.RaggedFloat(sentence_lengths.shape(), - sentence_lengths.values().to(torch.float32)) - assert isinstance(sentence_lengths, k2.RaggedFloat) + sentence_lengths = sentence_lengths.to(dtype=torch.float32) # Convert into a simple tensor of float by adding lengths of words. - sentence_lengths = k2.ragged.sum_per_sublist(sentence_lengths) + sentence_lengths = sentence_lengths.sum() assert isinstance(sentence_lengths, torch.Tensor) assert sentence_lengths.dtype == torch.float32 diff --git a/egs/librispeech/ASR/conformer_lm/test_dataset.py b/egs/librispeech/ASR/conformer_lm/test_dataset.py index b82da7899d..4cadaa9392 100644 --- a/egs/librispeech/ASR/conformer_lm/test_dataset.py +++ b/egs/librispeech/ASR/conformer_lm/test_dataset.py @@ -9,8 +9,6 @@ def local_collate_fn(sentences): return dataset.collate_fn(sentences, bos_sym=1, eos_sym=1, blank_sym=0, debug=True) -x = _k2.RaggedInt('[[1]]') # make sure library initialized? - if __name__ == '__main__': #mp.set_start_method('spawn') diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py index 04c3f8ccde..2d1c1a4c35 100755 --- a/egs/librispeech/ASR/conformer_lm/train.py +++ b/egs/librispeech/ASR/conformer_lm/train.py @@ -136,13 +136,14 @@ def get_params() -> AttributeDict: # exp_4, vs. exp_3, is using the Gloam optimizer with # in exp_5, vs. exp_4, we change Gloam to have a 1/sqrt(t) factor # as well as the exponential part. - "exp_dir": Path("conformer_lm/exp_5"), + # exp_6, we change the decay from 0.85 to 0.9. 
+ "exp_dir": Path("conformer_lm/exp_6"), "lm_dataset": Path("data/lm_training_5000/lm_data.pt"), "num_tokens": 5000, "blank_sym": 0, "bos_sym": 1, "eos_sym": 1, - "start_epoch": 0, + "start_epoch": 2, "num_epochs": 20, "num_valid_batches": 200, "symbols_per_batch": 5000, @@ -529,7 +530,7 @@ def run(rank, world_size, args): model.parameters(), max_lrate=params.max_lrate, first_decrease_epoch=1, - decay_per_epoch=0.85 + decay_per_epoch=0.9 ) if checkpoints: diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py index b6e0931f40..a836bb0172 100755 --- a/egs/librispeech/ASR/local/prepare_lm_training_data.py +++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py @@ -82,8 +82,8 @@ def main(): sentences.append([ word2index[w] for w in line_words]) output = dict() - output['words' ] = k2.ragged.create_ragged2(words2bpe) - output['data'] = k2.ragged.create_ragged2(sentences) + output['words' ] = k2.ragged.RaggedTensor(words2bpe) + output['data'] = k2.ragged.RaggedTensor(sentences) torch.save(output, args.lm_archive) print(f"Saved to {args.lm_archive}") diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 94c408c6e4..0e7dc510fb 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -3,7 +3,7 @@ set -eou pipefail nj=15 -stage=-1 +stage=9 stop_stage=100 # We assume dl_dir (download dir) contains the following @@ -195,7 +195,7 @@ fi if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} - lm_dir=lm_dir=data/lm_training_${vocab_size} + lm_dir=data/lm_training_${vocab_size} mkdir -p $lm_dir log "Stage 9: creating $lm_dir/lm_data.pt" ./local/prepare_lm_training_data.py data/lang_bpe_${vocab_size}/bpe.model download/lm/librispeech-lm-norm.txt $lm_dir/lm_data.pt From 8e650a584134e9cd42216427ac1f2f2a0ae45b74 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 27 Sep 2021 12:19:01 +0800 Subject: [PATCH 26/26] Update egs/librispeech/ASR/conformer_lm/conformer.py Co-authored-by: Fangjun Kuang --- egs/librispeech/ASR/conformer_lm/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_lm/conformer.py b/egs/librispeech/ASR/conformer_lm/conformer.py index fe0a5eec98..d08b64ea37 100644 --- a/egs/librispeech/ASR/conformer_lm/conformer.py +++ b/egs/librispeech/ASR/conformer_lm/conformer.py @@ -112,7 +112,7 @@ def forward( Returns: Returns (memory, pos_emb), where: - `memory` is a Tensor containing the encoded data; it is of shape (N, T, C) + `memory` is a Tensor containing the encoded data; it is of shape (T, N, C) where C is the embedding_dim. `pos_emb` is a Tensor containing the relative positional encoding, of shape (1, 2*T-1, C)