Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid nan during sampling in generate() #17937

Closed
wants to merge 2 commits into from

Conversation

ydshieh
Copy link
Collaborator

@ydshieh ydshieh commented Jun 29, 2022

What does this PR do?

Fix CI test error

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
>           next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
E           RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

in
https://github.com/huggingface/transformers/runs/6959698965?check_suite_focus=true

The test test_sample_generate may still fail at

self.assertListEqual(output_sample.tolist(), output_generate.tolist())

for some unknown reason. I think it is better to investigate this in another PR.

@ydshieh ydshieh changed the title fix nan during sampling fix nan during sampling in generate() Jun 29, 2022
@ydshieh ydshieh changed the title fix nan during sampling in generate() Avoid nan during sampling in generate() Jun 29, 2022
@ydshieh
Copy link
Collaborator Author

ydshieh commented Jun 29, 2022

I have some doubts here, as this will make all tokens having equal probability to be sampled. But with all -inf, nothing could be sampled which leads to error. I feel there is no well-defined expected results in such edge cases.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jun 29, 2022

The documentation is not available anymore as the PR was closed or merged.

@@ -1970,8 +1970,19 @@ def sample(
else (outputs.hidden_states,)
)

# To avoid all `-inf` along the vocab dimension (dim -1), which gives `nan` after `softmax` and error
# in `torch.multinomial`.
_next_token_scores = torch.max(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ydshieh, softmax should be able to handle -inf correctly actually.
You can try:

torch.nn.functional.softmax(torch.tensor([0, float("-inf")]))

which works as mathematically expected.

It's only when all values are -inf that it doesn't work in which case this fix won't help because the generation is broken.

Copy link
Collaborator Author

@ydshieh ydshieh Jun 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will fix the nan issue actually. The concern is that it doesn't really make sense, as it changes the probability to uniform distribution along vocab dim, while in the broken cases, it is nothing can't be sampled (all probability 0 , mathematically)

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As explained here: https://github.com/huggingface/transformers/pull/17937/files#r910424357
this won't fix the problem. Also note that generation is used a lot so it's every additional operation (torch.max(...)) leads to a tiny slow down.

Usually if you get nan's after the softmax it means that the generation is broken anyways which can happen and I think there is little we can do against it

@ydshieh
Copy link
Collaborator Author

ydshieh commented Jun 30, 2022

Yes, that happens only when all -inf along the vocab dim. I will close this PR, and we have to maybe create a doc with all possible flaky tests :-)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants