
Problems when converting fairseq model to hf format #28174

Closed
2 of 4 tasks
upskyy opened this issue Dec 21, 2023 · 6 comments · May be fixed by #28250

Comments

@upskyy

upskyy commented Dec 21, 2023

System Info

  • transformers version: 4.37.0.dev0
  • Platform: Linux-5.15.0-88-generic-x86_64-with-glibc2.35
  • Python version: 3.10.8
  • Huggingface_hub version: 0.19.4
  • Safetensors version: 0.3.2
  • Accelerate version: 0.21.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 1.13.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed

Who can help?

@sanchit-gandhi

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Thanks for releasing this awesome repo.

Issue 1

I am converting a fairseq checkpoint (wav2vec2_conformer) to the Hugging Face format. The conversion itself completes without errors, but the two models produce different results.
While debugging, I found a difference from the fairseq implementation: fairseq only creates an nn.Linear projection when the convolutional subsampling dimension differs from the encoder dimension, whereas Hugging Face creates it unconditionally, so the converted model ends up running inference through randomly initialized weights.

fairseq

https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py#L324-L328

self.post_extract_proj = (
    nn.Linear(self.embed, cfg.encoder_embed_dim)
    if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input
    else None
)

huggingface

https://github.com/huggingface/transformers/blob/main/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py#L536

class Wav2Vec2ConformerFeatureProjection(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
        self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)  # <-- HERE
        self.dropout = nn.Dropout(config.feat_proj_dropout)

    def forward(self, hidden_states):
        # non-projected hidden states are needed for quantization
        norm_hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.projection(norm_hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states, norm_hidden_states

I think it should be something like this:

class Wav2Vec2ConformerFeatureProjection(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
        # mirror fairseq: only project when the conv output and encoder dims differ
        if config.conv_dim[-1] != config.hidden_size:
            self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
        else:
            self.projection = None
        self.dropout = nn.Dropout(config.feat_proj_dropout)
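
The forward pass would then need a matching guard so the projection is skipped when it does not exist. A minimal sketch of the corresponding forward for the class above:

    def forward(self, hidden_states):
        # non-projected hidden states are needed for quantization
        norm_hidden_states = self.layer_norm(hidden_states)
        # skip the projection entirely when conv and encoder dims already match
        if self.projection is not None:
            hidden_states = self.projection(norm_hidden_states)
        else:
            hidden_states = norm_hidden_states
        hidden_states = self.dropout(hidden_states)
        return hidden_states, norm_hidden_states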

Issue 2

Also, fairseq applies layer norm before the input enters the conformer encoder, whereas Hugging Face always applies it after the encoder, with no option to change this. Could this be handled as an option? I think this also contributes to the differing results. (A sketch of what such an option could look like follows the links below.)

fairseq

https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py#L1230-L1231

def extract_features(self, x, padding_mask=None, tgt_layer=None):
    if padding_mask is not None:
        x = index_put(x, padding_mask, 0)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # B X T X C here
    position_emb = None
    if self.pos_enc_type == "rel_pos":
        position_emb = self.embed_positions(x)

    if not self.layer_norm_first:  # <-- HERE
        x = self.layer_norm(x)

    x = F.dropout(x, p=self.dropout, training=self.training)

    layer_results = []
    r = None
    for i, layer in enumerate(self.layers):
        dropout_probability = np.random.random()
        if not self.training or (dropout_probability > self.layerdrop):
            x, z = layer(
                x,
                self_attn_padding_mask=padding_mask,
                need_weights=False,
                position_emb=position_emb,
            )
            if tgt_layer is not None:
                layer_results.append((x, z))
        if i == tgt_layer:
            r = x
            break

huggingface

https://github.com/huggingface/transformers/blob/main/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py#L929

Expected behavior

What do you think about this problem?
If these modifications are acceptable, I can open a PR that also includes a conversion script covering the fairseq extension.

@amyeroberts
Collaborator

cc @ylacombe as well for reference

@ylacombe
Contributor

Hey @upskyy, thanks for opening this issue. It is very clear and in line with #28165, which converts a model from Seamless Communication and fairseq2.

We are supposed to have integration tests ensuring that the two implementations produce the same results, but they may very well be outdated or specific to certain wav2vec models.

Regarding your issues, could you provide the model you are testing and a script that replicates the differing results?

Regarding issue 1, we'd have to make sure that the case in which self.projection = None actually occurs with the proposed wav2vec2 checkpoints. If it never happens, there's no need to add unnecessary complexity!
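
For instance, a quick sanity check against a hub config (the model id below is just an example; any wav2vec2-conformer checkpoint works):

from transformers import Wav2Vec2ConformerConfig

config = Wav2Vec2ConformerConfig.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
# the projection-free case only matters when these two are equal
print(config.conv_dim[-1], config.hidden_size)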

Regarding issue 2, #28165 adds skip_encoder_layer_norm, a parameter that simply skips this layer norm. However, the name layer_norm_first implies that the norm might be computed somewhere else. In my case, skip_encoder_layer_norm is enough, but it might not generalize to your checkpoints.

Thanks again!

@upskyy
Author

upskyy commented Dec 22, 2023

@ylacombe Thanks for your reply.

So should I just wait for the #28165 PR to merge?
During fairseq training, the projection is only used when the last dimension of the convolutional subsampling differs from the dimension of the conformer encoder blocks.
For example, if both are 512-dimensional, the projection weight is simply absent from the fairseq checkpoint.
As a result, the conversion to Hugging Face format raises no error, but at inference time the Hugging Face model runs a randomly initialized projection, which ruins the results.
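
You can verify this directly on a checkpoint. A sketch, where the path is a placeholder and the "model" key assumes the usual fairseq checkpoint layout:

import torch

ckpt = torch.load("/path/to/fairseq_wav2vec2_conformer.pt", map_location="cpu")
state_dict = ckpt["model"]
# False when conv dim == encoder dim, i.e. no projection was ever trained
print("post_extract_proj.weight" in state_dict)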

Thanks : )

@ylacombe
Contributor

Hey @upskyy, #28165 definitely won't solve your issue 1, though it might solve issue 2.
Could you open a PR with your proposed solution? And could you also point me to a checkpoint that has no projection weights?
Many thanks!

@upskyy
Author

upskyy commented Jan 5, 2024

@ylacombe

I've opened a PR; please check it.
Thanks : )

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Feb 6, 2024