
beam_sample throws a nan error on long generations #22914

Closed

fpgaminer opened this issue Apr 21, 2023 · 8 comments

@fpgaminer
Contributor

System Info

  • transformers version: 4.29.0.dev0
  • Platform: Linux-5.15.0-67-generic-x86_64-with-glibc2.35
  • Python version: 3.10.10
  • Huggingface_hub version: 0.13.4
  • Safetensors version: 0.3.0
  • PyTorch version (GPU?): 2.0.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

It seems that beam_sample throws a NaN exception when generating long sequences. Specifically, it fails at the call next_tokens = torch.multinomial(probs, num_samples=2 * num_beams). Example generate call that triggers the bug:

output_sequences = model.generate(
	input_ids=encoded_prompt,
	max_length=512 + len(encoded_prompt[0]),
	temperature=0.7,
	num_return_sequences=1,
	num_beams=2,
	do_sample=True,
)

This reliably throws a NaN error on my system and on @diegomontoya's system. In my testing it occurs when the requested number of new tokens is roughly >= 256; the example above uses 512 just to be sure.

Based on the debugging I've done so far, what's happening is that the magnitude of beam_scores grows exponentially with each iteration of the inner beam search loop. It keeps growing until it reaches a very large negative number, causing next_token_scores to contain all -inf, which makes probs all nan, and then multinomial throws.

As for why this occurs, a rough summary of the inner loop elucidates:

while not done:
    next_token_scores = ...
    next_token_scores = next_token_scores + beam_scores
    next_token_scores = logits_warper(..., next_token_scores)

    beam_scores = beam_scorer.process(..., beam_scores, next_token_scores)

Specifically, beam_scores feeds back into itself with every iteration. If the inner loop were additive only, this would be fine, and beam_scores would grow linearly with length. But this is not the case: logits_warper makes the loop non-additive. With the settings above it behaves approximately like multiplying next_token_scores by 1.5 (temperature 0.7 divides the scores by 0.7). Hence beam_scores grows exponentially and the function eventually throws.
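
A toy recurrence makes the growth concrete (my own sketch, not transformers code; the per-token log-probability of -2.0 is an arbitrary assumption):

# Toy model of the feedback loop above: beam_scores is added to the log-probs,
# the temperature warper divides the sum by T = 0.7, and the result is fed
# back in as the next beam score.
temperature = 0.7
log_prob = -2.0    # assumed typical per-token log-probability
beam_score = 0.0

for step in range(1, 261):
    warped = (beam_score + log_prob) / temperature  # logits_warper on (scores + beam_scores)
    beam_score = warped                             # beam_scorer.process feeds it back
    if step % 64 == 0:
        print(f"step {step:3d}: beam_score ~ {beam_score:.3e}")

# The magnitude grows roughly like (1/0.7)**step and exceeds the float32 range
# (~3.4e38) after about 245 steps, which lines up with the ~256-token threshold
# where next_token_scores collapses to all -inf.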

I don't know enough about how beam_sample is meant to function to analyze further. It does seem odd to me, though, that the sampling is dependent on the current beam score. Since the beam score is a scalar value, it affects the probabilities of all tokens equally, so ... it shouldn't have any effect at all? So why apply it to the sampling logic? It seems more reasonable to me, and would indeed fix this bug, if it were added after sampling and before handing the scores off to the BeamScorer for processing.

Expected behavior

generate shouldn't throw a nan error under reasonable circumstances.

@gante
Member

gante commented Apr 21, 2023

Hey @fpgaminer 👋

My first recommendation would be to use "normal" sample, perhaps with a slightly lower temperature. If you think about it, beam_sample is a sample-based strategy that greedily picks the best scores among the drawn sequences, which is similar to sample with a lower temperature (which also favors high-scoring tokens). sample is also faster (no beam-related operations), and subject to much more maintenance :)

If you still want to use beam_sample, my recommendation would be to add the remove_invalid_values flag (docs).

@fpgaminer
Contributor Author

Hello @gante,

Thanks for the response. I have no intention of using beam sampling myself. I'm bubbling up a bug report by @diegomontoya from my GPTQ-triton repo that turned out to be a bug in transformers itself. It was a curious enough bug that I got nerd-sniped by it...

> If you still want to use beam_sample, my recommendation would be to add the remove_invalid_values flag (docs).

I don't think that would work. The bug results from beam_scores exploding, which drives all the scores down to -inf. Invalid tokens are removed in the logits_processor pass, before beam_scores is added. Even if it were applied after, it would just set all tokens to the maximum value, which I think would cause the softmax -> multinomial step to throw anyway.


I've looked at the code more and read up on beam search. I think my initial take is correct: I see no reason to feed beam_scores to the logit processors. It's a scalar value added to all the logits/probs, so what effect could it possibly have? Temperature, for example, is completely unaffected, as shown below:

Suppose we have a vector `x`
Softmax is `e**x / sum(e**x)`

Suppose we add a scalar `b`: `x + b`
Softmax is now: `e**(x + b) / sum(e**(x + b))`
Exponential law: `e**x * e**b / sum(e**x * e**b)`
Simplify: `e**x * e**b / (sum(e**x) * e**b)`
Simplify: `e**x / sum(e**x)`
Q.E.D.
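
A quick numerical check of the same argument (my own sketch, arbitrary values):

import torch

# Adding a scalar (the beam score) to every logit leaves the softmax
# distribution unchanged, even with a temperature applied, because the
# constant cancels in the normalization.
torch.manual_seed(0)
x = torch.randn(10)    # arbitrary logits
b = -123.4             # stand-in for a (negative) beam score
temperature = 0.7

p_without = torch.softmax(x / temperature, dim=-1)
p_with = torch.softmax((x + b) / temperature, dim=-1)
print(torch.allclose(p_without, p_with, atol=1e-6))  # True, until b is large enough to overflow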

It's possible that b, aka the beam score, has an effect on other logit processors, but I can't fathom what effect one would want it to have on things like top p, top k, typical, etc. I'd have to go through each in more detail to have a stronger opinion here. It just feels wrong, since I think all those logit processors were introduced in the context of greedy sampling. They weren't designed to take a global scalar like beam score into account.

So I argue that beam_sample should be modified to not include the beam_scores when calling logits_warper, and when doing multinomial sampling. It should be added after the tokens have been sampled.


I also think there is other oddness to the way beam_sample samples. Consider the simplified forms of sample vs beam_sample:

sample:

next_token_logits = outputs.logits[:, -1, :]
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

beam_sample:

next_token_logits = outputs.logits[:, -1, :]
next_token_scores = nn.functional.log_softmax(next_token_logits, dim=-1)
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
next_token_scores = logits_warper(input_ids, next_token_scores)
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
... beam search stuff ...

Why does beam_sample apply a log_softmax to the logits before feeding them to logits_processor when the sample method doesn't? That seems odd, especially when all the logit processors are expecting, well, logits, not the log softmax of logits.

The same goes for logits_warper, which also applies a sequence of LogitsProcessors. They aren't likely to be expecting log-softmaxed values.

And then softmax gets applied afterwards to values in the log softmax domain... very confusing.


So I propose for beam_sample (simplified/pseudo):

next_token_logits = outputs.logits[:, -1, :]
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
... gather tokens, scores ...
... add beam_scores to respective scores ...
... beam processing ...

> If you think about it, beam_sample is a sample-based strategy that greedily picks the best scores among the drawn sequences, which is similar to sample with a lower temperature (which also favors high-scoring tokens). sample is also faster (no beam-related operations), and subject to much more maintenance :)

My quick take: sure, maybe. But in theory beam search and beam sampling still provide potential value over low temp sampling. They can explore the landscape more thoroughly and potentially find more globally optimal sequences that a greedy sampling method usually won't. I dunno.

I'm personally more in the "better logit processors" and "better models" camp than the "futzing with beam search" camp. But since HF includes beam sampling, might as well make it work as well as possible?

@Qubitium
Contributor

Qubitium commented Apr 22, 2023

@gante I am not qualified to comment on the internal code itself, so I will only report from a user-level perspective:

  1. Adding remove_invalid_values=True does not resolve the issue. I still get the exact same nan/inf exception with num_beams=2 whenever the expected total (input + output) token count is > 256. I added it both to the generation config and directly to the generate() method, and it still threw exceptions (a sketch of both invocations follows after this list). Am I using it correctly?

probability tensor contains either `inf`, `nan` or element < 0

  2. Having read up on the basic concepts of beam search, and also Hugging Face's own interpretation of it, I don't understand why users have to care about a remove_invalid_values toggle. Isn't it implied that the generate wrapper, which most users and external libraries use, should automatically remove and bypass any invalid values during generation? This also creates a chicken-and-egg problem: if we don't set remove_invalid_values, only a runtime call to generate will reveal that inf/nan tokens were produced, and only then can we apply a remove_invalid_values pass, which negates any performance gain. As a result, as an end user I will always set remove_invalid_values when num_beams > 1; but if the two options are symbiotic, they should be handled internally by the library rather than exposed to the user.

  3. I am using beam search because I believe it may resolve an issue that is outlined by the beam search principle. I could lower the temperature, but that requires that:

  • I can detect when my result from a higher temperature is wrong, which is very difficult for my problem set.
  • Even if I can detect an error due to the higher temperature, I need to re-run the pass at a lower temperature, which is basically beams in operation.
  • It is not possible to predetermine whether a lower or higher temperature gives the better answer. In my test case for beam search, I am relying on the idea that num_beams=2 selects two paths and, only at the end, compares the probability scores of the results and gives me the best one.
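
For reference, a minimal sketch of the two invocations described in point 1 (model and encoded_prompt as in the reproduction at the top of the issue; the other parameter values are assumptions following that snippet). Both forms still raised the same nan/inf exception for me:

from transformers import GenerationConfig

# 1) flag passed directly to generate()
output = model.generate(
    input_ids=encoded_prompt,
    max_length=512 + len(encoded_prompt[0]),
    do_sample=True,
    num_beams=2,
    temperature=0.7,
    remove_invalid_values=True,
)

# 2) flag passed via a GenerationConfig
gen_config = GenerationConfig(
    max_length=512 + len(encoded_prompt[0]),
    do_sample=True,
    num_beams=2,
    temperature=0.7,
    remove_invalid_values=True,
)
output = model.generate(input_ids=encoded_prompt, generation_config=gen_config)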

@gante
Member

gante commented Apr 22, 2023

@fpgaminer @diegomontoya Let me split my comment into three parts: remove_invalid_values, how beam sample is implemented, and a suggestion based on @diegomontoya's 3rd point in the last comment :)


remove_invalid_values was created to avoid errors with extreme numbers, as a last resort. When it needs to be used, it means that there is something unstable in the process. I was double-checking it and it is missing the -inf case, which is probably why it didn't immediately solve your case (I'll open a PR). However, it should still be avoided, and the cases where you actually need it are very very uncommon.

Isn't it implied that generate wrapper, which most user and external libs use, should auto remove and bypass any invalid values during gen stages?

Definitely not. Our guiding principles for building blocks like .generate(), sorted by priority, are: 1. keep retrocompatibility (unless it is to fix bugs) and 2. build a default behavior that works in most cases and minimizes black-box behavior. Having remove_invalid_values on by default would go against 2 -- if there is something wrong with the generation strategy, we'd rather surface it to the user.


The same discussion and arguments you wrote about beam_sample were also written in the past, by myself included :) (a few examples: 1 2).

TL;DR: I agree with your point of view, but a) beam_sample is not an official implementation, so the order of operations is not right or wrong, it is a matter of its creator's taste, and b) because of the principles above, ensuring retrocompatibility > individual opinion.

Our codebase is fully open, so feel free to monkey patch on your end any different perspective 🤗 And my apologies for the nerd snipe, beam methods are indeed a strong magnet!


@diegomontoya if beam sample keeps failing after I add the -inf case and monkey patching is not an option, try the following:

  1. Use sample
  2. Set num_return_sequences to an integer, which will make generate return that many sequences per input
  3. Set output_scores and return_dict_in_generate to True, so you have access to the scores
  4. Pick the output with the highest score (this function may help)

This is essentially a poor man's version of beam sample. While beam sample greedily optimizes the score in the intermediate steps, this approach retains full randomness.
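
A rough sketch of that recipe (not an official example; model and encoded_prompt as in the reproduction above, and the use of compute_transition_scores for step 4 is my assumption about the helper function being referenced):

# Steps 1-3: plain sampling, several sequences per input, scores kept.
outputs = model.generate(
    input_ids=encoded_prompt,
    do_sample=True,
    temperature=0.7,
    max_new_tokens=512,
    num_return_sequences=4,
    output_scores=True,
    return_dict_in_generate=True,
)

# Step 4: score each returned sequence and keep the best one.
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=True
)
sequence_scores = transition_scores.sum(dim=-1)   # higher = more likely
best_sequence = outputs.sequences[sequence_scores.argmax()]
# (for variable-length outputs you may want to mask padding and/or
# length-normalize before comparing scores)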


I hope this (long) comment helps you understand why we make certain decisions, even if you don't agree with them :)

@Qubitium
Contributor

@gante Thank you. I got much more info in return than I had hoped for; not only did it clarify things for me, but your poor man's beam sample really opened my mind about how I should properly use and approach generate as a whole going forward.

@gante
Copy link
Member

gante commented May 1, 2023

btw, the error you've seen is very likely related to this one: #22979

TL;DR -- pytorch's sampling function is buggy atm and can pick tokens with 0 probability 👀
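
A rough way to check a given setup for that behaviour (my own sketch, not from the linked issue):

import torch

# Draw many samples from a distribution with an explicit zero-probability
# entry and see whether that index is ever returned; on affected builds it can be.
device = "cuda" if torch.cuda.is_available() else "cpu"
probs = torch.tensor([0.5, 0.5, 0.0], device=device).repeat(10_000, 1)
draws = torch.multinomial(probs, num_samples=1)
print("zero-probability index drawn:", (draws == 2).any().item())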

@Daryl149

Daryl149 commented May 24, 2023

Just adding that it could be CUDA, bitsandbytes and pytorch related.

The same error happens for me as well on torch==1.13.1 with model call:
tokens = model.generate(**inputs, max_new_tokens=500, do_sample=True, temperature=0.9, streamer=streamer)

This call does not throw the error, but returns gibberish:
tokens = model.generate(**inputs, max_new_tokens=25, do_sample=True, num_beams=1, temperature=0.9, streamer=streamer, remove_invalid_values=True)
returns for example:
ováBit}")VAjem ubuntu米 alwaysicago connectingselection Rewrite perceMillBLoll Forschavano economic pygindi Pent öss fs file

For me the issue happens on my multi gpu ubuntu 22.04 system with CUDA 12.0 (python detects 11.8 interestingly).
It does not happen on my single gpu ubuntu 20.04 system with CUDA 11.6.

Also, this only happens when I load the model in 8-bit with bitsandbytes. Loading the model without load_in_8bit=True is very slow (5-10 seconds per token), but returns text that makes sense and does not throw any error.

Further testing shows that after downgrading from CUDA 11.8 to CUDA 11.6, I no longer receive this error when using load_in_8bit=True and tokens = model.generate(**inputs, max_new_tokens=25, do_sample=True, temperature=0.9, streamer=streamer). However, I still get gibberish results:
ток hastICEyk char sunny少 hardwareington chi GraphSecondsesser引 conser conformygieneOriuvimplughtub.
The winning combo for 'no error and words that make sense' seems to be either:

  • CUDA 11.6, load_in_8bit=True and a single GPU system.
  • or CUDA 11.6, load_in_8bit=False and a multi GPU system.

**Update:** it's not pytorch related; it happens for both 2.0.1 and 1.13.1. See #23989

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
