
Using run_mlm.py to pretrain a roberta base model from scratch outputs do not include <bos> or <eos> tokens #21711

Closed
2 of 4 tasks
Rallio67 opened this issue Feb 20, 2023 · 19 comments

Comments

@Rallio67

Rallio67 commented Feb 20, 2023

System Info


  • transformers version: 4.27.0.dev0
  • Platform: Linux-5.15.0-60-generic-x86_64-with-glibc2.31
  • Python version: 3.9.16
  • Huggingface_hub version: 0.12.1
  • PyTorch version (GPU?): 1.13.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: deepspeed

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I am attempting to train a roberta-base model using the defaults on a custom corpus.

deepspeed --num_gpus 8 run_mlm.py \
    --model_type roberta \
    --max_seq_length 128 \
    --do_train \
    --per_device_train_batch_size 512 \
    --fp16 \
    --save_total_limit 3 \
    --num_train_epochs 30 \
    --deepspeed ds_config.json \
    --learning_rate 1e-4 \
    --eval_steps 50 \
    --max_eval_samples 4000 \
    --evaluation_strategy steps \
    --tokenizer "roberta-large" \
    --warmup_steps 30000 \
    --adam_beta1 0.9 \
    --adam_beta2 0.98 \
    --adam_epsilon 1e-6 \
    --weight_decay 0.01 \
    --lr_scheduler_type linear \
    --preprocessing_num_workers 8 \
    --train_file my_text.txt \
    --line_by_line \
    --output_dir my_roberta_base

The training runs fine: the loss goes down and the accuracy goes up. However, when I compare the outputs to the original roberta-base, I see behavior that appears to be a glitch or a problem with the training.

Expected behavior

Expected behavior using roberta-base from the Hugging Face Hub shows the first and last tokens of the output being the <s> (BOS) and </s> (EOS) tokens, respectively, while my newly trained roberta-base model shows token #8 ("and") in those positions. I think this was learned instead of the first and last positions being automatically set to <s> and </s>, which is the expected behavior for this script.

from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

tokenizer = AutoTokenizer.from_pretrained("roberta-base")

# Compare the reference checkpoint from the Hub with the checkpoint trained by run_mlm.py
model1 = AutoModelForMaskedLM.from_pretrained("roberta-base", torch_dtype=torch.float16).cuda(0)
model2 = AutoModelForMaskedLM.from_pretrained("rob_wiki_base", torch_dtype=torch.float16).cuda(0)

text = "The main causes of death for <mask> are human-related issues, such as habitat destruction and human objects. Their slow-moving, curious <mask> has led to violent collisions with propeller-driven boats and ships. Some manatees have been found with over 50 scars on them from propeller <mask>. Natural causes of death include adverse temperatures, predation by <mask> on young, and disease."
inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt")

output1 = model1(inputs["input_ids"].cuda(0))
output2 = model2(inputs["input_ids"].cuda(0))

# Greedy prediction for every position, including the first and last (<s>/</s>) positions
predicted_token_id1 = output1[0][0].argmax(axis=-1)
predicted_token_id2 = output2[0][0].argmax(axis=-1)

print("Original roberta-base output:")
print(predicted_token_id1)
print(tokenizer.decode(predicted_token_id1))
print("-" * 20)
print("My new roberta-base output:")
print(predicted_token_id2)
print(tokenizer.decode(predicted_token_id2))
print("-" * 20)

Original roberta-base output:
tensor([ 0, 133, 1049, 4685, 9, 744, 13, 18018, 32, 1050,
12, 3368, 743, 6, 215, 25, 14294, 8181, 8, 1050,
8720, 4, 2667, 2635, 12, 19838, 6, 10691, 3650, 34,
669, 7, 4153, 25062, 19, 39238, 12853, 12, 9756, 8934,
8, 7446, 4, 993, 313, 877, 293, 33, 57, 303,
19, 81, 654, 26172, 15, 106, 31, 39238, 12853, 5315,
4, 7278, 4685, 9, 744, 680, 12661, 3971, 6, 12574,
1258, 30, 22139, 15, 664, 6, 8, 2199, 4, 2],
device='cuda:0')

The main causes of death for whales are human-related issues, such as habitat destruction and human objects. Their slow-moving, curious behavior has led to violent collisions with propeller-driven boats and ships. Some manatees have been found with over 50 scars on them from propeller strikes. Natural causes of death include adverse temperatures, predation by predators on young, and disease.

My new roberta-base output:
tensor([ 8, 133, 1049, 4685, 9, 744, 13, 5868, 32, 1050,
12, 3368, 743, 6, 215, 25, 14294, 8181, 8, 1050,
8720, 4, 2667, 2635, 12, 19838, 6, 10691, 2574, 34,
669, 7, 4153, 25062, 19, 39238, 12853, 12, 9756, 8934,
8, 7446, 4, 993, 313, 877, 293, 33, 57, 303,
19, 81, 654, 26172, 15, 106, 31, 39238, 12853, 5315,
4, 7278, 4685, 9, 744, 680, 12661, 3971, 6, 12574,
1258, 30, 5868, 15, 664, 6, 8, 2199, 4, 8],
device='cuda:0')
andThe main causes of death for humans are human-related issues, such as habitat destruction and human objects. Their slow-moving, curious nature has led to violent collisions with propeller-driven boats and ships. Some manatees have been found with over 50 scars on them from propeller strikes. Natural causes of death include adverse temperatures, predation by humans on young, and disease. and

@Rallio67
Author

Rallio67 commented Feb 21, 2023

The config.json files have a notable difference between roberta-base and my new pretrained roberta model.

max_position_embeddings in roberta-base is equal to 514, while in my new pretrained model it is set to 512.

I also notice that in the script there is a default setting to "mask special tokens":

We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it receives the special_tokens_mask.

return_special_tokens_mask=True,

Is it possible that this is the source of the issue? Thank you for any help that can be offered on this problem.
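For context, that option is set in the script's tokenization step; here is a paraphrased sketch of the line-by-line path (approximate, not an exact copy of run_mlm.py):

def tokenize_function(examples):
    # Tokenize each line and return the special-tokens mask so the data collator
    # can skip special tokens efficiently (this is the option quoted above).
    return tokenizer(
        examples[text_column_name],
        padding=padding,
        truncation=True,
        max_length=max_seq_length,
        return_special_tokens_mask=True,
    )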

@sgugger
Collaborator

sgugger commented Feb 21, 2023

cc @ArthurZucker and @younesbelkada

@Rallio67
Author

Any updates on this? I would appreciate any help in identifying the source of this bug.

@ArthurZucker
Collaborator

Hey, this should probably be asked on the forum, as it is not a bug, and we cannot reproduce your issue here (the model is private).

  1. The training might have gone wrong.
  2. The generation_config or config file might be wrong. Both your bos_token and eos_token look wrong: 0 and 2 have changed to 8.

It would be great if you could check the eos, pad, and bos token arguments and make sure that the inputs you feed to the model are the same.
Also, be careful with the formatting of your issue; it is very hard to read. If you want a fast answer, this plays a bit against you 😉
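A quick check along those lines might look like this (a minimal sketch; "rob_wiki_base" is the local checkpoint directory used in the snippet above):

from transformers import AutoConfig, AutoTokenizer

# Load the tokenizer and config saved by run_mlm.py
tokenizer = AutoTokenizer.from_pretrained("rob_wiki_base")
config = AutoConfig.from_pretrained("rob_wiki_base")

# For a RoBERTa-style setup these should be <s>/0, </s>/2 and <pad>/1
print(tokenizer.bos_token, tokenizer.bos_token_id)
print(tokenizer.eos_token, tokenizer.eos_token_id)
print(tokenizer.pad_token, tokenizer.pad_token_id)
print(config.bos_token_id, config.eos_token_id, config.pad_token_id)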

@Rallio67
Author

Maybe there is some misunderstanding in what I posted. To the best of my knowledge, I am using an unmodified, default training script from Hugging Face on a plain text file, with the default configuration for roberta (a model that has been on HF for two years or more, I think). I did a fresh install of transformers from source on an 8x A100 instance.

see here:
https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py

I ran the script using only the default configuration commands (posted above) on a text file with the default roberta configuration, but the outputs are never the correct 0 and 2. Any configuration I am using is automatically generated by the training script, and I run the generation script exactly as I do with roberta-base, just substituting the model directory generated by run_mlm.py.

If I am running the script with all default parameters, I think this qualifies as a bug?

@ArthurZucker
Collaborator

Okay! Thanks for clarifying, I will have a look as soon as I can. It seems like a bug indeed

@Rallio67
Author

Rallio67 commented Feb 24, 2023

The troubleshooting I did myself makes me think it has something to do with the special tokens being masked out in the training dataset preparation. Normally, masking special tokens makes sense for some language models (like the <pad> token), but in this case I think you don't want the BOS/EOS masked. The reason token 8 shows up in those positions is that the word "and" is extremely common, and predicting it there minimizes the overall loss. The model was never configured to use token 8 (early in training it would be a random common token like a period ".", "the", or "and"). Overall the model is still training and working well; it's just never generating the BOS/EOS tokens in the "unmasked" output.
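One way to see this effect directly (a minimal sketch, assuming the same "roberta-large" tokenizer used in the training command above): DataCollatorForLanguageModeling never selects special tokens for masking, so their labels come out as -100 and the loss never sees the <s>/</s> positions.

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("roberta-large")
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

# return_special_tokens_mask=True is what run_mlm.py passes by default
enc = tokenizer("Some manatees have been found with scars.", return_special_tokens_mask=True)
batch = collator([enc])

# Labels at the first and last positions (<s> and </s>) are always -100,
# i.e. the loss is never computed on them.
print(batch["labels"][0, 0].item(), batch["labels"][0, -1].item())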

@ArthurZucker
Collaborator

Ok, that's fairly interesting.
Normally, when generating, the bos token should be forced via the logits processor, so if you generate using model.generate, I am guessing this won't happen even if you have the masks.
It is okay if these tokens are attention masked; I think they should always be forced (during training, for example, the decoder input ids should always start with the bos token so that it is not predicted, and then the loss is not computed on it).
Does that make sense?

@Rallio67
Author

The roberta-base and roberta-large models on the Hugging Face Hub, when used with model.generate, do properly produce the BOS/EOS tokens. The output from my checkpoints inserts an extra first and last token, but that token is not BOS/EOS and appears to be learned.

@sarthusarth

sarthusarth commented Apr 3, 2023

Is there any update on this issue? I'm facing the same error. @ArthurZucker

@ArthurZucker
Collaborator

ArthurZucker commented Apr 4, 2023

The troubleshooting I did myself makes me think it has something to do with the special tokens being masked out in the training dataset preparation. Normally, masking special tokens makes sense for some language models (like the <pad> token), but in this case I think you don't want the BOS/EOS masked. The reason token 8 shows up in those positions is that the word "and" is extremely common, and predicting it there minimizes the overall loss. The model was never configured to use token 8 (early in training it would be a random common token like a period ".", "the", or "and"). Overall the model is still training and working well; it's just never generating the BOS/EOS tokens in the "unmasked" output.
So regarding the lead posted here by @Rallio67, I think I agree with him:

  • The special tokens should not be masked out when computing the loss: if you want the model to learn that it has to predict the eos and bos tokens, the loss must be computed on them. This is visible here, as the model ends up learning to predict the most common words at the beginning and end instead of predicting bos and eos.
    I suggest trying it out without the special mask (see the sketch below), and if that works for you I'll try to find a fix that does not break backward compatibility!
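As a rough illustration of what trying it "without the special mask" could look like (only a sketch, not an official option; the subclass name and the pad-only exclusion are my own choices), one could override the collator's torch_mask_tokens so that only padding, rather than all special tokens, is excluded from masking:

import torch
from transformers import DataCollatorForLanguageModeling

class CollatorKeepingBosEos(DataCollatorForLanguageModeling):
    # Same masking recipe as the stock collator, but only <pad> is excluded,
    # so <s>/</s> can be masked and therefore appear in the loss.
    def torch_mask_tokens(self, inputs, special_tokens_mask=None):
        labels = inputs.clone()
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        pad_mask = labels.eq(self.tokenizer.pad_token_id)
        probability_matrix.masked_fill_(pad_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # loss only on masked positions

        # 80%: replace masked inputs with the mask token
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.mask_token_id

        # 10%: replace with a random token (the remaining 10% keep the original token)
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]
        return inputs, labels

Using something like this in place of the default collator would let the labels at the <s>/</s> positions stay in the loss.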

@sarthusarth

Training without special tokens also doesn't work; I'm not sure what the reason is then.

@ArthurZucker
Collaborator

ArthurZucker commented Apr 6, 2023

Without special tokens or without special masks?

@sarthusarth

I trained it with return_special_tokens_mask=False, but only for 3 epochs (is it possible that it will be able to learn this when I train it fully)?

@ArthurZucker
Collaborator

Yep, if you can, it would be great to see the result after the same amount of training as the model that raised the issue.

@sarthusarth

I trained the model for 75 epochs; the BOS and EOS tokens are still not appearing.

@ArthurZucker ArthurZucker reopened this May 25, 2023
@ArthurZucker
Collaborator

Hey! I won't really have time to dive deep into this one. Could you share some example inputs that are fed to the model? (I forgot to ask about the contents of my_text.txt.) If the tokenizer does not pass bos and eos (by that I mean does not add them), it might be that the default roberta tokenizer can't be used out of the box for this, or something else.

@ArthurZucker
Collaborator

Okay, here is a very relevant comment: #22794 (comment). It is important to note that, when the script calls torch_mask_tokens, the loss is only computed on the masked tokens, and there is a call to masked_fill_(special_tokens_mask, value=0.0) which sets the probability of masking special tokens to 0. This means that the following code:

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

will always set the labels for eos and bos to -100, so they are always ignored.

If you remove the special tokens mask, it is recreated automatically using get_special_tokens_mask, which is why the tokens are not learned either way.
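A quick way to see why setting return_special_tokens_mask=False does not change anything (a minimal sketch): the mask rebuilt via get_special_tokens_mask still flags <s> and </s>, so their labels end up as -100 either way.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
ids = tokenizer("hello world")["input_ids"]  # starts with 0 (<s>) and ends with 2 (</s>)
mask = tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
# <s> and </s> are flagged with 1, so the collator zeroes their masking probability.
print(list(zip(tokenizer.convert_ids_to_tokens(ids), mask)))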

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
