[Speculative decoding 1/9] Optimized rejection sampler #2336
The next PR will be cadedaniel#1; I will create it once this one is merged.
# Create masks using the indices.
indices = torch.arange(k, device=accepted.device).unsqueeze(0)
accepted_mask = indices < limits.unsqueeze(1)
What's the difference between accepted and accepted_mask?
accepted is the result of the rejection sampling condition. accepted_mask is True up until the first position rejected by the rejection sampling condition.
Example for k=3, bs=5:
>>> accepted
tensor([[ True, False,  True],
        [False, False, False],
        [ True,  True, False],
        [ True,  True,  True],
        [False,  True, False]])
>>> accepted_mask
tensor([[ True, False, False],
        [False, False, False],
        [ True,  True, False],
        [ True,  True,  True],
        [False, False, False]])
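The example above can be reproduced with a short sketch. Only the last two lines (`indices` and `accepted_mask`) come from the diff; the way `limits` is derived here (index of the first rejected position in each row, or `k` when nothing is rejected) is an assumption made for illustration and may differ from the PR.

```python
import torch

# The `accepted` tensor from the example above (bs=5, k=3).
accepted = torch.tensor([
    [True, False, True],
    [False, False, False],
    [True, True, False],
    [True, True, True],
    [False, True, False],
])
batch_size, k = accepted.shape

# Assumed derivation of `limits`: the index of the first rejected position
# in each row. argmax returns the first maximal index, i.e. the first True.
rejected = ~accepted
limits = rejected.int().argmax(dim=1)
# Rows with no rejection at all accept all k draft tokens.
limits[~rejected.any(dim=1)] = k

# Create masks using the indices (as in the diff).
indices = torch.arange(k, device=accepted.device).unsqueeze(0)
accepted_mask = indices < limits.unsqueeze(1)
print(accepted_mask)
```

Note the first and last rows: a token accepted after an earlier rejection is discarded, because it was drafted conditioned on a token that is no longer part of the sequence.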
super().__init__()
self.probs_dtype = torch.float32
self.token_id_dtype = torch.int64
self._num_bonus_tokens = 1
When can num_bonus_tokens > 1? Is it the last token generated by the target model iff all drafted tokens are accepted?
when can num_bonus_tokens > 1?
It is always 1. This variable is for readability only. I'll add a comment.
Is it the last generated token by the target model iff all drafted tokens are accepted?
Yep!
Very exciting work! I hope this feature can be merged soon, as many other frameworks such as TGI, TRT-LLM, llama.cpp, and gpt-fast already support speculative sampling.
f = torch.clamp(difference, min=self._smallest_positive_value)

# shape [batch_size, k, vocab_size]
recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)
Nit: torch.multinomial does not require the probability to be normalized.
I will leave this in to keep the maths consistent with https://arxiv.org/pdf/2302.01318.pdf. This operation is not the compute or scheduling bottleneck.
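Both points in this thread can be checked with a small sketch. The probabilities are random, and `target_probs`, `draft_probs`, and the value used for `smallest_positive_value` are assumptions for illustration, not values taken from the PR.

```python
import torch

torch.manual_seed(0)
batch_size, k, vocab_size = 2, 3, 4

# Hypothetical target and draft distributions, shape [batch_size, k, vocab_size].
target_probs = torch.softmax(torch.randn(batch_size, k, vocab_size), dim=-1)
draft_probs = torch.softmax(torch.randn(batch_size, k, vocab_size), dim=-1)

# Assumed stand-in for self._smallest_positive_value.
smallest_positive_value = torch.finfo(torch.float32).tiny

difference = target_probs - draft_probs
f = torch.clamp(difference, min=smallest_positive_value)

# Explicit normalization, as in the diff: every row now sums to 1.
recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)

# torch.multinomial takes 1-D or 2-D input, so flatten the leading dimensions.
# Rows are treated as non-negative weights, so sampling from the
# unnormalized `f` is also a valid call.
from_weights = torch.multinomial(f.reshape(-1, vocab_size), num_samples=1)
from_probs = torch.multinomial(recovered_probs.reshape(-1, vocab_size), num_samples=1)
```

Sampling from `f` and from `recovered_probs` draws from the same distribution; the explicit division only makes `recovered_probs` a proper probability tensor that matches the formula in the paper.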
Speculative decoding
This PR is a part of a larger series of PRs implementing speculative decoding, contributed to open source vLLM by Anyscale. See #2188 and Speculative decoding open sourcing plan for more information.
Rejection sampling
This PR implements optimized rejection sampling, including the following features:
It also contributes tests which verify the rejection sampler's ability to approximate distributions, given enough samples.
The following people contributed to it: @cadedaniel @Yard1 @amogkam
Details
The basic idea behind rejection sampling is that one can sample from the target distribution (larger model) using samples from a proposal distribution (smaller draft model), while guaranteeing the output distribution is equivalent to the target distribution.
"Modified" rejection sampling is introduced in the paper. It ensures that at least one token will always be emitted from the rejection sampling routine, even if all proposal tokens are rejected.
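For a single token position, the scheme can be sketched in plain Python as follows. This is a minimal illustration, not the batched GPU implementation in this PR; the function name, the toy distributions, and the sample count are all hypothetical. A proposal is accepted with probability min(1, target/draft); on rejection, a replacement is drawn from the normalized positive part of (target - draft), so a token is always emitted.

```python
import random


def modified_rejection_sample(target_probs, draft_probs, rng):
    """Draw one token id distributed according to target_probs."""
    vocab = range(len(target_probs))
    # Propose a token from the draft distribution.
    proposal = rng.choices(vocab, weights=draft_probs, k=1)[0]
    # Accept with probability min(1, target(x) / draft(x)).
    accept_prob = min(1.0, target_probs[proposal] / draft_probs[proposal])
    if rng.random() < accept_prob:
        return proposal
    # On rejection, resample from the positive part of (target - draft).
    # random.choices treats the weights as relative, so no normalization is needed.
    residual = [max(0.0, t - d) for t, d in zip(target_probs, draft_probs)]
    return rng.choices(vocab, weights=residual, k=1)[0]


rng = random.Random(0)
target_probs = [0.6, 0.3, 0.1]  # toy target distribution (assumed)
draft_probs = [0.2, 0.5, 0.3]   # toy draft distribution (assumed)
num_samples = 20_000

counts = [0] * len(target_probs)
for _ in range(num_samples):
    counts[modified_rejection_sample(target_probs, draft_probs, rng)] += 1

empirical = [c / num_samples for c in counts]
print(empirical)  # should be close to target_probs
```

Comparing `empirical` with `target_probs` is the same kind of check as the distribution-approximation tests mentioned above, only at toy scale.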
With LLMs, modified rejection sampling can reduce latency because multiple proposal sequences can be evaluated at once (batching on the GPU).
Finally, the paper introduces the notion of a "bonus" token. In the case where all proposed tokens are accepted, an additional token can be emitted. This is possible by having the target model predict the next token given the entire proposed sequence as context.
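How accepted, recovered, and bonus tokens combine for one sequence can be sketched as below. The helper `assemble_output` and its arguments are hypothetical names used for illustration; the PR operates on batched tensors rather than Python lists.

```python
def assemble_output(draft_tokens, accepted, recovered_tokens, bonus_token):
    """Return the tokens emitted for one sequence of k draft tokens.

    draft_tokens:     k token ids proposed by the draft model
    accepted:         k booleans from the rejection sampling condition
    recovered_tokens: k token ids resampled from the recovered distribution
    bonus_token:      token sampled from the target model after all k drafts
    """
    output = []
    for token, ok, recovered in zip(draft_tokens, accepted, recovered_tokens):
        if ok:
            output.append(token)
        else:
            # First rejection: emit the recovered token and stop.
            output.append(recovered)
            return output
    # Every draft token was accepted, so the bonus token is emitted too.
    output.append(bonus_token)
    return output


print(assemble_output([11, 12, 13], [True, True, True], [21, 22, 23], 99))
# [11, 12, 13, 99]
print(assemble_output([11, 12, 13], [True, False, True], [21, 22, 23], 99))
# [11, 22]
print(assemble_output([11, 12, 13], [False, False, False], [21, 22, 23], 99))
# [21]
```

Each call emits between 1 and k + 1 tokens, which is why at least one token is always produced and why the bonus token appears only when nothing was rejected.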
Visual confirmation that modified rejection sampling approximates the target distribution:
code for visualizations: https://gist.github.com/cadedaniel/07c1cd4ac003f51140b205580ac02613