beam_sample throws a nan error on long generations #22914
Hey @fpgaminer 👋 My first recommendation would be to use "normal" sampling instead of beam sampling. If you still want to use beam sampling, the discussion below digs into what's going on.
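(A minimal sketch of the "normal" sampling route; the model handle and parameter values here are illustrative assumptions, not the exact call from the thread.)

```python
# Plain sampling: do_sample=True with num_beams=1 avoids the
# beam-score feedback loop entirely.
output = model.generate(
    input_ids,
    do_sample=True,
    num_beams=1,
    temperature=0.7,
    max_new_tokens=512,
)
```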
Hello @gante, thanks for the response. I have no intention of using beam sampling myself. I'm bubbling up a bug report by @diegomontoya from my GPTQ-triton repo that turned out to just be a bug in `transformers`' `beam_sample`.
I don't think that would work. The bug results from `beam_scores` being fed through the warpers in the first place.

I've looked at the code more, and read up on beam search more. I think my initial take is correct. I see no reason to feed the `beam_scores` to the logit processors. It's a scalar value added to all the logits/probs, so what effect could it possibly have? Temperature, for example, is completely unaffected, as proven like so:
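(The original snippet was lost in extraction; a minimal check of the same claim, the shift-invariance of softmax under temperature, might look like this.)

```python
import torch

# Adding a scalar "beam score" to every logit does not change the
# temperature-sampled distribution, because softmax is shift-invariant:
# softmax((l + c)/T) == softmax(l/T + c/T) == softmax(l/T).
logits = torch.randn(32000)
beam_score = -123.45   # arbitrary scalar offset
temperature = 0.7

p_plain = torch.softmax(logits / temperature, dim=-1)
p_shifted = torch.softmax((logits + beam_score) / temperature, dim=-1)
print(torch.allclose(p_plain, p_shifted, atol=1e-6))  # True
```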
It's possible that I'm missing something, of course. So I argue that `beam_scores` shouldn't be fed to the logit processors at all.

I also think there is other oddness in how the two methods differ. sample:
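(The two comparison snippets below are paraphrased from the transformers source of that era, not verbatim; assume `F = torch.nn.functional` and the usual generate-loop state such as `input_ids` and `beam_scores`.)

```python
# sample: processors and warpers see only the current step's logits.
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
probs = F.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1)
```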
beam_sample:
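And the corresponding `beam_sample` step (same caveats as above):

```python
# beam_sample: the running beam score is mixed in *before* the warpers run.
next_token_scores = F.log_softmax(next_token_logits, dim=-1)
next_token_scores = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
next_token_scores = logits_warper(input_ids, next_token_scores)  # e.g. temperature
probs = F.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
```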
Why does `beam_sample` take a `log_softmax` of the logits before running the processors, when `sample` hands them the raw logits? The same goes for the warpers, which in `beam_sample` operate on scores that already have `beam_scores` mixed in. And then those warped scores, beam offset included, are what get handed back as the new beam scores. So I propose for beam_sample (simplified/pseudo):
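(The pseudocode block that followed was lost; this sketch reconstructs the proposal described above, keeping `beam_scores` out of the sampled distribution.)

```python
# Proposed beam_sample step: sample from the per-step distribution alone,
# and only add the running beam score afterwards, for ranking.
next_token_scores = F.log_softmax(next_token_logits, dim=-1)
next_token_scores = logits_processor(input_ids, next_token_scores)
next_token_scores = logits_warper(input_ids, next_token_scores)

probs = F.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
sampled_scores = torch.gather(next_token_scores, -1, next_tokens)

# beam_scores enters only here, so it can never feed back through the warpers.
next_token_scores_for_scorer = sampled_scores + beam_scores[:, None]
```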
My quick take: sure, maybe. But in theory beam search and beam sampling still provide potential value over low-temperature sampling: they can explore the landscape more thoroughly and potentially find more globally optimal sequences that a greedy sampling method usually won't. I dunno. I'm personally in the "better logit processors" and "better models" camp rather than futzing with beam search. But since HF includes beam sampling, might as well make it work as well as possible?
@gante I am not qualified to comment on the internal code itself, so I will only report from a user-level perspective.
@fpgaminer @diegomontoya Let me split my comment in three:
Definitely not. Our guiding principles for building blocks like `generate` are to stay faithful to the original implementations and to preserve backward compatibility, so changing the numerical behavior of an established method is off the table.

The same discussion and arguments you wrote about have come up before for the beam methods. TL;DR: I agree with your point of view, but a) we can't quietly change results that existing users rely on, and b) our codebase is fully open, so feel free to monkey patch any different perspective on your end 🤗 And my apologies for the nerd snipe, beam methods are indeed a strong magnet!

@diegomontoya if beam sample keeps failing after I add the fix, let me know. In the meantime, you can approximate it by drawing several plain samples and keeping the best-scoring one, e.g.:
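(A sketch of that recipe; the model name and parameters are placeholders, not the exact snippet from the thread.)

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The meaning of life is", return_tensors="pt")

# Draw several independent samples instead of beam sampling.
outputs = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.7,
    num_return_sequences=8,
    max_new_tokens=64,
    output_scores=True,
    return_dict_in_generate=True,
)

# Score each candidate by its summed token log-probs, then keep the best.
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=True
)
best = outputs.sequences[transition_scores.sum(dim=1).argmax()]
print(tokenizer.decode(best, skip_special_tokens=True))
```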
This is essentially a poor man's version of beam sample: while beam sample greedily optimizes the score in the intermediary steps, this retains full randomness. I hope this (long) comment helps in understanding why we make certain decisions, even if you don't agree with them :)
@gante Thank you. I got much more info than I had hoped for in return, and not only did it clarify things for me, but your poor man's beam really opened up my mind about how I should properly use and approach my future usage of `generate` as a whole.
btw, the error you've seen is very likely related to this one: #22979. TL;DR -- pytorch's sampling function is buggy atm, being able to pick tokens with 0 probability 👀
Just adding that it could be CUDA-, bitsandbytes-, and pytorch-related. The same error happens for me as well. A variant of the call does not throw the error, but returns gibberish. For me the issue happens on my multi-GPU Ubuntu 22.04 system with CUDA 12.0 (Python detects 11.8, interestingly). Also, this only happens when I load the model in 8-bit with `load_in_8bit=True`. Further testing shows that after downgrading from CUDA 11.8 to CUDA 11.6, I no longer receive this error with the same call.
**Update:** it's not pytorch related; it happens for both 2.0.1 and 1.13.1. See #23989
System Info

`transformers` version: 4.29.0.dev0

Who can help?

@gante
Reproduction
It seems that `beam_sample` throws a NaN exception when generating long sequences. Specifically, it throws at the call `next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)`. An example `generate` call that causes the bug is sketched below. It reliably throws a NaN on my system and @diegomontoya's system. In my testing this occurs when the requested number of new tokens is roughly >= 256; in the example I use 512 just to be sure.
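(The example call was lost in extraction; below is a hypothetical reconstruction consistent with the description. The model handle and exact parameter values are assumptions.)

```python
# Hypothetical repro: any beam-sampling generate call with temperature < 1
# and enough new tokens should hit the NaN, per the analysis below.
output = model.generate(
    input_ids,
    do_sample=True,
    num_beams=4,          # beam sampling = do_sample + num_beams > 1
    temperature=0.7,      # the warper scales scores by 1/T ≈ 1.4 each step
    max_new_tokens=512,   # roughly >= 256 new tokens triggers the NaN
)
```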
Based on the debugging I've done so far, what's happening is that `beam_scores` grows exponentially in magnitude (more and more negative) with each iteration of the inner beam search loop. It does this until it reaches a very large negative number, causing `next_token_scores` to contain all `-inf`, which causes `probs` to be all `nan`, and then `multinomial` throws.

As for why this occurs, a rough summary of the inner loop elucidates:
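(The summary code block was lost in extraction; the following paraphrases the relevant steps from the transformers source of that era, not verbatim. Assume `F = torch.nn.functional`; `input_ids`, `beam_scores`, etc. are the loop's running state.)

```python
# Simplified beam_sample inner loop. Note the feedback path: beam_scores is
# added *before* the warpers, and the warped scores become the next beam_scores.
while not done:
    next_token_logits = model(input_ids).logits[:, -1, :]
    next_token_scores = F.log_softmax(next_token_logits, dim=-1)       # per-step log-probs
    next_token_scores = logits_processor(input_ids, next_token_scores)
    next_token_scores = next_token_scores + beam_scores[:, None]       # add running beam score
    next_token_scores = logits_warper(input_ids, next_token_scores)    # temperature divides the *sum*
    probs = F.softmax(next_token_scores, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)  # <- throws once probs is all nan
    next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
    # ... top-2*num_beams selection and index bookkeeping omitted ...
    beam_outputs = beam_scorer.process(input_ids, next_token_scores, next_tokens, next_indices)
    beam_scores = beam_outputs["next_beam_scores"]                     # warped scores feed back in
```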
Specifically, `beam_scores` feeds back into itself with every iteration. If the inner loop were additive only, this would be fine, and `beam_scores` would increase linearly with length. But this is not the case: `logits_warper` makes the loop non-additive. In the example above it behaves as approximately multiplying `next_token_scores` by 1.5. Hence `beam_scores` goes exponential and the function eventually throws.

I don't know enough about how `beam_sample` is meant to function to analyze further. It does seem odd to me, though, that the sampling is dependent on the current beam score. Since the beam score is a scalar value, it affects the probabilities of all tokens equally, so ... it shouldn't have any effect at all? So why apply it to the sampling logic? It seems more reasonable to me, and would indeed fix this bug, if it were added after sampling and before handing the scores off to the BeamScorer for processing.
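To make "goes exponential" concrete, here is a toy model of the feedback (assumed numbers: a per-token log-prob of -2 and a temperature warper that scales scores by 1.5, the factor mentioned above):

```python
# The recursion beam_score <- (beam_score + log_p) * (1/T) diverges
# geometrically whenever 1/T > 1, i.e. whenever temperature < 1.
beam_score, log_p, inv_T = 0.0, -2.0, 1.5
for step in range(40):
    beam_score = (beam_score + log_p) * inv_T
print(beam_score)  # ≈ -6.6e7 after only 40 tokens
```

With these assumed numbers the score overflows float32 to `-inf` after roughly 215 steps, at which point softmax produces all-NaN probs, which lines up with failures appearing around >= 256 new tokens.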
Expected behavior

`generate` shouldn't throw a `nan` error under reasonable circumstances.