.add_embeddings not getting the right embedding size #383

Closed
eugene-yang opened this issue Jul 4, 2022 · 0 comments · Fixed by #386
Labels
bug Something isn't working

Comments

@eugene-yang

Environment info

  • adapter-transformers version: v3.0.1+ (commit 11bd9d2)
  • Platform: Arch Linux
  • Python version: 3.10
  • PyTorch version (GPU?):
  • Tensorflow version (GPU?):
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Information

Model I am using (Bert, XLNet ...): XLMR

Language I am using the model on (English, Chinese ...):

Adapter setup I am using (if any):

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

I'm not sure whether this should be an upstream issue in transformers, where the .vocab_size attribute is not updated after tokens are added, or whether that is intended behavior. Either way, I believe we should respect the actual vocabulary size the user intends to use.
https://github.com/adapter-hub/adapter-transformers/blob/master/src/transformers/adapters/model_mixin.py#L155

Steps to reproduce the behavior:

from transformers import AdapterSetup, AutoAdapterModel, AutoTokenizer, AdapterConfig
model = AutoAdapterModel.from_pretrained('xlm-roberta-base')
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
tokenizer_new = AutoTokenizer.from_pretrained('xlm-roberta-base')
tokenizer_new.add_tokens(['[unused1]'])
tokenizer.vocab_size # > 250002
len(tokenizer_new) # > 250003
tokenizer_new.vocab_size # > 250002

model.add_embeddings('new', tokenizer_new, reference_tokenizer=tokenizer, reference_embedding='default')
model.base_model.loaded_embeddings
# > {'default': Embedding(250002, 768, padding_idx=1),  'new': Embedding(250002, 768)}
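
As a quick programmatic check of the mismatch (illustrative only, building on the snippet above; num_embeddings is the standard nn.Embedding attribute):

new_emb = model.base_model.loaded_embeddings['new']
# Fails on v3.0.1: the new embedding was created with tokenizer_new.vocab_size
# (250002) rather than len(tokenizer_new) (250003).
assert new_emb.num_embeddings == len(tokenizer_new)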

Expected behavior

The input dimension of the new embedding at the end of the example should be 250003, i.e. len(tokenizer_new).

model.add_embeddings('new', tokenizer_new, reference_tokenizer=tokenizer, reference_embedding='default')
model.base_model.loaded_embeddings
# > {'default': Embedding(250002, 768, padding_idx=1),  'new': Embedding(250003, 768)}

I believe (and have tested) that an easy fix would be changing this line to embedding = nn.Embedding(len(tokenizer), embedding_dim); see the sketch below.

But I'm not sure whether this issue should be fixed here in adapter-transformers or upstream in the transformers tokenizer code.
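
For concreteness, here is a minimal sketch of the proposed change, assuming the embedding is constructed roughly as in the linked model_mixin.py line; the helper name and default dimension are illustrative, not the library's actual API.

import torch.nn as nn

def _build_new_embedding(tokenizer, embedding_dim=768):
    # Current behavior: tokenizer.vocab_size does not count tokens added via
    # tokenizer.add_tokens(), so the new matrix comes out one row too small.
    # embedding = nn.Embedding(tokenizer.vocab_size, embedding_dim)

    # Proposed fix: len(tokenizer) includes the added tokens (250003 in the
    # example above), so the new embedding matches the actual vocabulary.
    return nn.Embedding(len(tokenizer), embedding_dim)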

@eugene-yang added the bug label on Jul 4, 2022
calpt added a commit that referenced this issue Jul 11, 2022
- Introduces a new `EmbeddingAdaptersWrapperMixin` to make embedding methods available to heads model classes. This is implemented in new per-model heads mixins. Closes #382.
- Fixes size issues with embeddings. Closes #383.
- Detach embedding weights before cloning. Closes #384.