Beam Search Fails for Llama 70b #26332

Closed
2 of 4 tasks
jconley-deloitte opened this issue Sep 21, 2023 · 9 comments · Fixed by #26843

Comments

@jconley-deloitte

System Info

  • transformers version: 4.33.2
  • Platform: Linux-5.15.0-1041-aws-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.17.2
  • Safetensors version: 0.3.3
  • Accelerate version: 0.23.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes, 6 A100 GPUs
  • Using distributed or parallel set-up in script?: Yes, device_map="auto", which I believe distributes the layers across the GPUs

Who can help?

@gante appears to be the relevant developer, because this is an issue with model.generate.

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The script below only needs a TOKEN and CACHE_DIR filled in and can be run to reproduce the error.
I have tried with and without autocast, and it does not affect the result.
I have also verified that the GPUs/machine are not memory constrained.
Greedy generation works as expected; only beam search is failing.

import os
import sys
from copy import deepcopy
import pandas as pd
import torch
from tqdm import tqdm

sys.path.append("/app")
from src.secrets import TOKEN
from src.constants import CACHE_DIR
from transformers import AutoModelForCausalLM, AutoTokenizer

tqdm.pandas()
# Use 6/8 A100 GPUs 
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3,4,5,6,7"

# Load the model
path = "meta-llama/Llama-2-70b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(path, cache_dir=CACHE_DIR, device_map="auto", torch_dtype=torch.float16, use_auth_token=TOKEN, use_cache=True)
# Setup Tokenizer
tokenizer = AutoTokenizer.from_pretrained(path, truncation_side="left", token=TOKEN)

prompts = [
    "<s>[INST]Hi, how are you?[/INST]",
    "<s>[INST]Who is the president?[/INST]",
    "<s>[INST]What continent is Ireland in?[/INST]",
    "<s>[INST]How fast can a chicken run?[/INST]",
    
]

for prompt in prompts:
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096, add_special_tokens=False).input_ids
    # Move tokens to GPU
    input_ids = input_ids.to("cuda")
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            output = model.generate(input_ids, num_beams=4, max_new_tokens=512, temperature=0.5, top_p=0.9)
    gen_text = tokenizer.batch_decode(output)
    print(gen_text[0])

The resulting error:

RuntimeError                              Traceback (most recent call last)
Cell In[5], line 7
      4 input_ids = input_ids.to("cuda")
      5 with torch.no_grad():
      6     #with torch.cuda.amp.autocast():
----> 7     output = model.generate(input_ids, num_beams=4, max_new_tokens=512, temperature=0.5, top_p=0.9)
      8 gen_text = tokenizer.batch_decode(output)
      9 print(gen_text[0])

File /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1718, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1710     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1711         input_ids=input_ids,
   1712         expand_size=generation_config.num_beams,
   1713         is_encoder_decoder=self.config.is_encoder_decoder,
   1714         **model_kwargs,
   1715     )
   1717     # 14. run beam sample
-> 1718     return self.beam_sample(
   1719         input_ids,
   1720         beam_scorer,
   1721         logits_processor=logits_processor,
   1722         logits_warper=logits_warper,
   1723         stopping_criteria=stopping_criteria,
   1724         pad_token_id=generation_config.pad_token_id,
   1725         eos_token_id=generation_config.eos_token_id,
   1726         output_scores=generation_config.output_scores,
   1727         return_dict_in_generate=generation_config.return_dict_in_generate,
   1728         synced_gpus=synced_gpus,
   1729         **model_kwargs,
   1730     )
   1732 elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
   1733     # 11. prepare beam search scorer
   1734     beam_scorer = BeamSearchScorer(
   1735         batch_size=batch_size,
   1736         num_beams=generation_config.num_beams,
   (...)
   1742         max_length=generation_config.max_length,
   1743     )

File /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:3392, in GenerationMixin.beam_sample(self, input_ids, beam_scorer, logits_processor, stopping_criteria, logits_warper, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)
   3388 next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
   3390 probs = nn.functional.softmax(next_token_scores, dim=-1)
-> 3392 next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
   3393 next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
   3395 next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)

RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

Expected behavior

The model generates tokens using beam search.

@jconley-deloitte
Author

I also tried installing from GitHub and received the same error. See the updated environment below (same script/error).

  • transformers version: 4.34.0.dev0
  • Platform: Linux-5.15.0-1041-aws-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.16.4
  • Safetensors version: 0.3.3
  • Accelerate version: 0.23.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes, 6 A100
  • Using distributed or parallel set-up in script?: Yes, device_map="auto", which I believe distributes the layers across the GPUs

@jconley-deloitte
Author

It's also worth mentioning that this error did not occur for lower numbers of generated tokens (around 10); it only happened for longer generations, such as the up-to-512-token runs above.

@ArthurZucker
Collaborator

cc @Rocketknight1, here is an example of how to produce the NaN.
This is somewhat expected; we have quite a few issues relating to Llama and NaN with batch generation. @gante is OOO, but we might merge a fix similar to #25284

@Rocketknight1
Member

Got it! I'll see if I can reproduce this and push a fix to LLaMA (which might also help bring the code into line with the InternLM code).

@Rocketknight1
Member

I made a much shorter reproduction script for the issue that doesn't need llama-70b - debugging is easier when I don't need to spin up 8 A100s!

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "meta-llama/Llama-2-7b-chat-hf"
# Setup Tokenizer
model = AutoModelForCausalLM.from_pretrained(path)
tokenizer = AutoTokenizer.from_pretrained(path)

prompts = [
    "<s>[INST]Hi, how are you?[/INST]",
    "<s>[INST]Who is the president?[/INST]",
    "<s>[INST]What continent is Ireland in?[/INST]",
    "<s>[INST]How fast can a chicken run?[/INST]",
    
]

for prompt in prompts:
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096, add_special_tokens=False).input_ids
    with torch.no_grad():
        output = model.generate(input_ids, num_beams=4, max_new_tokens=512, temperature=0.5, top_p=0.9)
    gen_text = tokenizer.batch_decode(output)
    print(gen_text[0])

The issue occurs on GPU and CPU, in float16/bfloat16/float32. It is only triggered by beam search, and doesn't occur with standard generation. Working on it!

@Rocketknight1
Member

Further update: this issue only occurs in 'beam sample' decoding, not 'beam search'. As a temporary workaround, @jconley-deloitte, you can add do_sample=False to the generate arguments to use beam search instead.
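
For reference, a minimal sketch of that workaround, reusing the model, tokenizer and input_ids names from the reproduction scripts above:

with torch.no_grad():
    # do_sample=False selects plain beam search, which never reaches the
    # torch.multinomial call that raises on NaN probabilities
    output = model.generate(input_ids, num_beams=4, max_new_tokens=512, do_sample=False)
print(tokenizer.batch_decode(output)[0])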

@Rocketknight1
Member

Rocketknight1 commented Sep 29, 2023

Got it: This is nothing to do with LLaMA's code at all! The cause is that LLaMA's generation_config specifies a combination of options (temperature=0.6, top_p=0.9) that interact badly with beam search.

The reason seems to be that temperature=0.6 produces very sharp distributions, which means that top_p removes all or almost all of the tokens that can be selected in each iteration. As generation continues, this eventually results in all of the possible choices being removed by top_p. As a result, next_token_scores contains only masked -inf values, and so the softmax over this creates NaN outputs due to division by zero (because exp(-inf) == 0).
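A standalone sketch of that failure mode (illustrative toy code, not the transformers internals): once every candidate has been masked to -inf, the softmax divides zero by zero and multinomial rejects the resulting probabilities.

import torch

masked_scores = torch.full((1, 8), float("-inf"))  # every candidate filtered out
probs = torch.softmax(masked_scores, dim=-1)       # exp(-inf) == 0 everywhere, so 0/0 -> NaN
print(probs)                                       # tensor([[nan, nan, ..., nan]])

try:
    torch.multinomial(probs, num_samples=2)
except RuntimeError as e:
    print(e)  # probability tensor contains either `inf`, `nan` or element < 0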

Possible solutions include tweaking the top_p warper to always return at least 1 logit, or altering beam search somehow when these warpers are present. Since it touches core generation code, I'm going to leave this fix until I can discuss it with @gante next week!
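As a toy illustration of the first idea only (not the actual TopPLogitsWarper implementation, which is discussed below), a top-p filter that always keeps at least one token could look roughly like this:

import torch

def top_p_filter(scores: torch.Tensor, top_p: float, min_tokens_to_keep: int = 1) -> torch.Tensor:
    # Sort candidates by score, highest first
    sorted_scores, sorted_idx = torch.sort(scores, descending=True, dim=-1)
    probs = torch.softmax(sorted_scores, dim=-1)
    # Drop a token if the probability mass accumulated *before* it already exceeds top_p
    remove = probs.cumsum(dim=-1) - probs > top_p
    remove[..., :min_tokens_to_keep] = False  # never mask out every token
    # Map the mask back to the original (unsorted) token order
    mask = remove.scatter(-1, sorted_idx, remove)
    return scores.masked_fill(mask, float("-inf"))

# A very peaked distribution: only the top token survives, but at least one always does
print(top_p_filter(torch.tensor([[10.0, 0.0, -5.0, -5.0]]), top_p=0.9))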

In the meantime @jconley-deloitte, you can either use do_sample=False, or call generate with different values for those arguments (e.g. temperature=1.0, top_p=1.0)
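The second option, sketched with the same assumed names from the earlier snippets, keeps beam sampling but passes neutral values so the warpers have nothing to distort:

with torch.no_grad():
    output = model.generate(
        input_ids,
        num_beams=4,
        max_new_tokens=512,
        do_sample=True,   # still beam sample
        temperature=1.0,  # neutral temperature
        top_p=1.0,        # no nucleus filtering
    )
print(tokenizer.batch_decode(output)[0])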

@gante
Member

gante commented Oct 16, 2023

@Rocketknight1 thank you for diving in!

> Possible solutions include tweaking the top_p warper to always return at least 1 logit

We already do this. The root of the issue is that in beam_sample we apply the logits processors before adding the beam scores, as usual in beam methods. However, for legacy reasons, we apply the logits warpers after adding the beam scores, which causes the scores to explode with temperatures below 1.0.
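
A rough numeric illustration of that ordering problem (toy numbers only, not the actual beam_sample internals): after many steps the cumulative beam score is a large negative number, and dividing by a temperature below 1.0 after adding it inflates the values, whereas warping the per-step logits first leaves the beam score untouched.

import torch

step_logits = torch.tensor([2.0, 1.0, -1.0])
beam_score = -120.0   # cumulative log-prob of a beam after many generated tokens
temperature = 0.6

# Legacy order: temperature applied after the beam score has been added
legacy = (torch.log_softmax(step_logits, dim=-1) + beam_score) / temperature

# Proposed order: warp the per-step logits first, then add the beam score
proposed = torch.log_softmax(step_logits / temperature, dim=-1) + beam_score

print(legacy)    # roughly [-200.6, -202.2, -205.6], magnitudes inflated by 1/temperature
print(proposed)  # roughly [-120.2, -121.8, -125.2], same scale as the beam score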

There has been a similar issue in the past, and, quite apart from this being a bug that causes crashes, I think it makes more sense to apply the logits warpers before adding the scores:
1 - the most important ones, like top_p and top_k, keep the same tokens regardless of where the operation is applied :)
2 - we also apply the logits processors before adding the scores

All this to say that I'm going to open a PR to break the legacy behavior, as it is a recurrent issue that takes up significant time every time it pops up :) I've tested locally, and changing this detail fixes the crashing snippets!

@ArthurZucker
Collaborator

➕ on breaking this, as we have had quite a lot of issues. Having a self.legacy flag might be OK to allow a deprecation cycle / just keep both for a while.
