
WIP: Try to use multiple datasets with pruned transducer loss #245

Closed
wants to merge 10 commits

Conversation

csukuangfj
Collaborator

It also refactors the decoder and joiner to remove the extra nn.Linear() layer.

Will try #229 with this PR.

@csukuangfj csukuangfj changed the title WIP: Try to use multiple dataset with pruned transducer loss WIP: Try to use multiple datasets with pruned transducer loss Mar 9, 2022
@pkufool
Collaborator

pkufool commented Mar 10, 2022

If we remove the extra nn.Linear(), the encoder_out_dim should be greater than the vocab-size; otherwise it will cause an error in rnnt_loss_simple/smoothed. It's OK for this PR (i.e., encoder_out_dim=512, vocab-size=500), but I think we should add some documentation to clarify that.

[edit] I mean the extra nn.Linear() in the decoder. The extra nn.Linear() in the joiner is there to reduce parameters (if the vocab-size is large) and can be removed.
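
For illustration, a minimal sketch (toy shapes, not the actual k2 code) of why this constraint arises when the encoder output is fed directly into the loss: the simple loss gathers the channel indexed by each symbol id from the encoder output, so every symbol id must be a valid channel index, i.e. encoder_out_dim must be at least vocab-size.

    import torch

    B, S, T = 2, 3, 4
    vocab_size = 500
    encoder_out_dim = 256  # deliberately smaller than vocab_size

    am = torch.randn(B, T, encoder_out_dim)
    # Some symbol ids exceed encoder_out_dim - 1:
    symbols = torch.tensor([[499, 10, 300], [5, 450, 260]])

    # Mirrors the px_am gather in k2.get_rnnt_logprobs (simplified):
    px_am = torch.gather(
        am.unsqueeze(1).expand(B, S, T, encoder_out_dim),
        dim=3,
        index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1),
    )
    # Raises RuntimeError: index out of bounds, because ids like 499
    # have no corresponding channel to gather from.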

@csukuangfj
Collaborator Author

It's OK for this PR

Yes, for the librispeech recipe we are using a vocab size of 500, so the nn.Linear() layers in the decoder and joiner are not necessary. We can keep them for aishell.

I think we should add some documentation to clarify that.

I don't know the underlying reason. Maybe we should document it in k2.

@danpovey
Collaborator

Guys,
I noticed when experimenting with systems with d_model=256 that they actually perform poorly if the joiner input has dim=256; it's necessary to set it to dim=512 (I did not try larger).
So I know I previously said we should just let that dim be the same as d_model, but in fact I'm not so sure about this now.

@danpovey
Collaborator

... if we're using the pruned-loss training, it might be worthwhile trying with encoder-output-dim = 1024.

@csukuangfj
Collaborator Author

the encoder_out_dim should be greater than the vocab-size; otherwise it will cause an error in rnnt_loss_simple/smoothed

Can it be fixed on the k2 side so that we can use a larger encoder_out_dim without adding extra nn.Linear() layers?

@danpovey
Collaborator

Guys, I'm not so enthusiastic about avoiding the extra linear layer if it requires that embedding_dim >= vocab_size.
The problem is that it makes the joiner kind of meaningless: it is not able to learn a nontrivial function, because we are forcing its input to have a particular meaning.

@csukuangfj
Collaborator Author

... if we're using the pruned-loss training, it might be worthwhile trying with encoder-output-dim = 1024.

In that case, more than half of the encoder outputs are not used in k2.get_rnnt_logprobs() when vocab size is 500.
https://github.com/k2-fsa/k2/blob/master/k2/python/k2/rnnt_loss.py#L155

    px_am = torch.gather(
        am.unsqueeze(1).expand(B, S, T, C),
        dim=3,
        index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1),
    ).squeeze(
        -1
    )  # [B][S][T]

You can see that only the first half of the channels of am is used, since entries in symbols are less than 500 while C is 1024.
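
A small numeric check of this point (toy shapes, not the training code): zeroing out channels 500..1023 of am leaves px_am unchanged, so those channels receive no gradient from the simple loss.

    import torch

    B, S, T, C, vocab_size = 1, 3, 4, 1024, 500
    am = torch.randn(B, T, C)
    symbols = torch.randint(0, vocab_size, (B, S))

    def gather_px_am(am):
        return torch.gather(
            am.unsqueeze(1).expand(B, S, T, C),
            dim=3,
            index=symbols.reshape(B, S, 1, 1).expand(B, S, T, 1),
        ).squeeze(-1)  # [B][S][T]

    am_zeroed = am.clone()
    am_zeroed[..., vocab_size:] = 0.0
    # Channels >= vocab_size are never gathered, so the result is identical.
    assert torch.equal(gather_px_am(am), gather_px_am(am_zeroed))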

@danpovey
Collaborator

@csukuangfj I think it is a mistake to be confusing or identifying the encoder_output_dim with the vocabulary size; I think there should be a projection from one to the other. But actually, in my opinion, it might make more sense to conceptualize the encoder_output_dim as the "hidden dim" of the joiner, i.e. where the nonlinearity (tanh) takes place.

That is: we'd change it so the network would have output of dim==attention-dim (i.e. no linear projection at the output), and we could project that in different ways in the Transducer model:
(i) project from d_model to vocab_size for use in simple/pruned loss; we could perhaps have a version of the Decoder that projects directly to vocab_size for this purpose.
(ii) have a separate projection from d_model to joiner_dim (which we conceptualize as a hidden-dim of the joiner), and have a separate version of the decoder, sharing the embedding, that projects to joiner_dim. joiner_dim could be a bit larger, like 1024. i.e. we'd have decoder and simple_decoder, where decoder output-dim == joiner_dim and simple_decoder output-dim == vocab_size.
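
A minimal nn.Module sketch of the layout described above; all names here (TransducerHeads, simple_am_proj, am_proj, joiner_dim, ...) are hypothetical and only meant to make the two projection paths concrete. In this wording, decoder would project to joiner_dim and simple_decoder to vocab_size while sharing the embedding; the sketch collapses that into two linear heads on a common d_model output.

    import torch
    import torch.nn as nn

    class TransducerHeads(nn.Module):
        def __init__(self, d_model: int, joiner_dim: int, vocab_size: int):
            super().__init__()
            # (i) projections to vocab_size, used only for the simple loss / pruning ranges
            self.simple_am_proj = nn.Linear(d_model, vocab_size)
            self.simple_lm_proj = nn.Linear(d_model, vocab_size)
            # (ii) projections to the joiner's hidden dim, where the tanh nonlinearity lives
            self.am_proj = nn.Linear(d_model, joiner_dim)
            self.lm_proj = nn.Linear(d_model, joiner_dim)
            self.joiner_out = nn.Linear(joiner_dim, vocab_size)

        def joiner(self, am: torch.Tensor, lm: torch.Tensor) -> torch.Tensor:
            # am, lm: (..., joiner_dim), already projected by am_proj / lm_proj
            return self.joiner_out(torch.tanh(am + lm))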


@csukuangfj
Collaborator Author

Thanks! I see. Will make the change.

mask = make_pad_mask(lengths)
x = self.encoder(x, src_key_padding_mask=mask)  # (T, N, C)

x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
Collaborator Author

The last nn.Linear() from the transformer model is removed.

if self.normalize_before:
    x = self.after_norm(x)

x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
Collaborator Author

The last nn.Linear() from the conformer model is removed.

src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
if not self.normalize_before:
    src = self.norm_ff(src)

Collaborator Author

The last nn.LayerNorm of the conformer encoder layer is also removed. Otherwise, when normalize_before is True:
(1) The output of the LayerNorm of the i-th encoder layer is fed into the input of the LayerNorm of the (i+1)-th encoder layer.

(2) The output of the LayerNorm of the last encoder layer is fed into the input of the LayerNorm in the conformer model.
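
A toy check of the back-to-back LayerNorm point: at initialization (affine weight=1, bias=0) the second LayerNorm is essentially a no-op on the output of the first, which is why keeping both is redundant.

    import torch
    import torch.nn as nn

    x = torch.randn(10, 512)
    ln1 = nn.LayerNorm(512)
    ln2 = nn.LayerNorm(512)

    # ln1(x) is already normalized per row, so normalizing it again changes almost nothing.
    assert torch.allclose(ln2(ln1(x)), ln1(x), atol=1e-4)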

"subsampling_factor": 4,
"attention_dim": 512,
"decoder_embedding_dim": 512,
"joiner_dim": 1024, # input dim of the joiner
Collaborator Author

Joiner dim is set to 1024.

boundary[:, 2] = y_lens
boundary[:, 3] = x_lens

simple_decoder_out = simple_decoder_linear(decoder_out)
Collaborator Author

Two nn.Linear() layers are used to transform the encoder output and decoder output for computing the simple loss.
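
A hedged sketch of this simple-loss path with dummy tensors. simple_encoder_linear / simple_decoder_linear stand in for the two nn.Linear() layers mentioned above; the k2.rnnt_loss_smoothed argument names follow my recollection of the k2 API at the time and should be checked against the k2 version in use.

    import k2
    import torch
    import torch.nn as nn

    N, T, S = 2, 20, 5
    encoder_out_dim, decoder_out_dim = 512, 512
    vocab_size, blank_id = 500, 0

    encoder_out = torch.randn(N, T, encoder_out_dim)      # from the conformer
    decoder_out = torch.randn(N, S + 1, decoder_out_dim)  # from the decoder
    symbols = torch.randint(1, vocab_size, (N, S))
    boundary = torch.zeros(N, 4, dtype=torch.int64)
    boundary[:, 2] = S  # y_lens
    boundary[:, 3] = T  # x_lens

    simple_encoder_linear = nn.Linear(encoder_out_dim, vocab_size)
    simple_decoder_linear = nn.Linear(decoder_out_dim, vocab_size)

    simple_encoder_out = simple_encoder_linear(encoder_out)
    simple_decoder_out = simple_decoder_linear(decoder_out)

    simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
        lm=simple_decoder_out,
        am=simple_encoder_out,
        symbols=symbols,
        termination_symbol=blank_id,
        boundary=boundary,
        return_grad=True,
    )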

am=simple_encoder_out, lm=simple_decoder_out, ranges=ranges
)

am_pruned = encoder_linear(am_pruned)
Collaborator Author

Two nn.Linear() layers are used to transform the pruned outputs to the dimension of joiner_dim.
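
Continuing the sketch from the simple-loss comment above (it reuses the tensors defined there), a hedged version of the pruned path as I read this hunk: the pruned outputs are projected to joiner_dim by encoder_linear / decoder_linear before the joiner. prune_range, the joiner internals, and the exact k2 argument names are assumptions to be checked against the k2 / icefall code.

    import k2
    import torch
    import torch.nn as nn

    joiner_dim = 1024
    prune_range = 3  # number of symbols kept per frame (assumed value)

    encoder_linear = nn.Linear(vocab_size, joiner_dim)
    decoder_linear = nn.Linear(vocab_size, joiner_dim)
    joiner_output_linear = nn.Linear(joiner_dim, vocab_size)

    ranges = k2.get_rnnt_prune_ranges(
        px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=prune_range
    )
    am_pruned, lm_pruned = k2.do_rnnt_pruning(
        am=simple_encoder_out, lm=simple_decoder_out, ranges=ranges
    )

    # Project the pruned outputs to joiner_dim, combine, and compute the pruned loss.
    am_pruned = encoder_linear(am_pruned)
    lm_pruned = decoder_linear(lm_pruned)
    logits = joiner_output_linear(torch.tanh(am_pruned + lm_pruned))

    pruned_loss = k2.rnnt_loss_pruned(
        logits=logits,
        symbols=symbols,
        ranges=ranges,
        termination_symbol=blank_id,
        boundary=boundary,
    )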

@danpovey
Collaborator

Cool. We should experiment to see whether joiner_dim=512 or joiner_dim=1024 works better, e.g. with a few epochs. I imagine 1024 will be an easy win, but we'll see.

@danpovey
Collaborator

danpovey commented May 4, 2022

Why is this not merged yet? Was it worse? [oh, I see, this is not the latest pruned_transducer_stateless2 setup...]

@csukuangfj
Collaborator Author

Why is this not merged yet? Was it worse? [oh, I see, this is not the latest pruned_transducer_stateless2 setup...]

Closing via #312

@csukuangfj csukuangfj closed this May 4, 2022
@csukuangfj csukuangfj deleted the pruned-multi-dataset branch May 4, 2022 08:13