
Conformer lm #54

Open · wants to merge 29 commits into master

Conversation

danpovey (Collaborator):

No description provided.

@csukuangfj (Collaborator):

We will add a decoding script to it.

@danpovey (Collaborator, Author):

See also https://tensorboard.dev/experiment/unF4gSyjRjua2DSKgb3BMg/
and /ceph-dan/icefall/egs/librispeech/ASR/conformer_m/exp_6

csukuangfj self-assigned this Sep 25, 2021
bos or eos symbols).
"""
# in future will just do:
#return self.words[self.sentences[i]].tolist()
Collaborator:

Supported by k2-fsa/k2#833
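
For illustration, here is a minimal plain-PyTorch sketch (not the k2 API; the helper names here are made up for this example) of what the commented-out one-liner would compute, assuming words and sentences are ragged arrays stored as (values, row_splits) pairs:

import torch

def ragged_index(values: torch.Tensor, row_splits: torch.Tensor, i: int) -> torch.Tensor:
    # Return the i-th sub-list of a ragged array stored as (values, row_splits).
    return values[row_splits[i]:row_splits[i + 1]]

def sentence_token_ids(words_values, words_row_splits,
                       sents_values, sents_row_splits, i: int) -> list:
    # Word indices that make up sentence i.
    word_ids = ragged_index(sents_values, sents_row_splits, i)
    # Concatenate the token IDs of each word, i.e. words[sentences[i]],
    # and convert to a Python list as the one-liner's .tolist() would do.
    return torch.cat([ragged_index(words_values, words_row_splits, int(w))
                      for w in word_ids]).tolist()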

@danpovey (Collaborator, Author) commented Sep 25, 2021 via email

cnn_module_kernel,
)
self.encoder = MaskedLmConformerEncoder(encoder_layer, num_encoder_layers,
norm=nn.LayerNorm(d_model))
Collaborator:

We are using pre-normalization here.
You have placed a layer norm at the end of the encoder layer, and you are using an extra layer norm here, which means you are doing

x = layernorm(layernorm(x))

See

if self.norm is not None:
    x = self.norm(x)

I just realized that it is even worse.

You are using a layer norm at both ends of an encoder layer, but the encoder layers are connected end-to-end, which means the output of the layer norm from the previous encoder layer is used as the input of the layer norm of the next encoder layer.

Collaborator (Author):

This "even worse" part is not quite right, because there are bypass connections, so there are paths involving residuals where the input is used without layer norm.

Collaborator (Author):

.. but yes, I agree that the LayerNorm at the end of the conformer encoder is redundant. I have since stopped using that. But in this particular case, it would take a long time to retrain the model if we were to fix it, so I'd say leave it as-is for now.

Collaborator:

so I'd say leave it as-is for now.

Yes, I agree. After finishing the decoding script, I would recommend removing it and re-running the whole pipeline.

lm_dir=data/lm_training_${vocab_size}
mkdir -p $lm_dir
log "Stage 9: creating $lm_dir/lm_data.pt"
./local/prepare_lm_training_data.py data/lang_bpe_${vocab_size}/bpe.model download/lm/librispeech-lm-norm.txt $lm_dir/lm_data.pt
Collaborator:

Suggested change:
- ./local/prepare_lm_training_data.py data/lang_bpe_${vocab_size}/bpe.model download/lm/librispeech-lm-norm.txt $lm_dir/lm_data.pt
+ ./local/prepare_lm_training_data.py $lang_dir/bpe.model $dl_dir/lm/librispeech-lm-norm.txt $lm_dir/lm_data.pt


# Calling this on all copies of a DDP setup will sync the sizes so that
# all copies have the exact same number of batches. I think
# this needs to be called with the GPU device, not sure if it would
# work otherwise.
Collaborator:

It will not work for CPU devices as DDP requires GPU devices.


def _sync_sizes(self, device: torch.device = torch.device('cuda')):
    # Calling this on all copies of a DDP setup will sync the sizes so that
    # all copies have the exact same number of batches. I think
Collaborator:

Shall we mention that without doing this, the training process
will hang indefinitely at the end?

Collaborator (Author):

Mm, sure...
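
For context, a rough sketch (not the actual implementation; the helper name and details are assumptions) of how such a size sync can be done with torch.distributed, taking the minimum batch count over all ranks on a CUDA tensor so the NCCL all-reduce works and no rank is left waiting at the end of the epoch:

import torch
import torch.distributed as dist

def sync_num_batches(num_batches: int,
                     device: torch.device = torch.device('cuda')) -> int:
    # No-op when not running under torch.distributed.
    if not (dist.is_available() and dist.is_initialized()):
        return num_batches
    # NCCL requires CUDA tensors, hence the GPU device argument.
    t = torch.tensor([num_batches], dtype=torch.int64, device=device)
    # Use the minimum so every rank iterates the same number of batches;
    # otherwise ranks with fewer batches finish the epoch early and the
    # remaining ranks block indefinitely in their next collective call.
    dist.all_reduce(t, op=dist.ReduceOp.MIN)
    return int(t.item())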

csukuangfj added a commit to csukuangfj/icefall that referenced this pull request Nov 1, 2021
csukuangfj added a commit to csukuangfj/icefall that referenced this pull request Nov 17, 2021