Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Speculative decoding 1/9] Optimized rejection sampler #2336

Merged
merged 10 commits into from
Jan 9, 2024

Conversation

cadedaniel
Copy link
Collaborator

Speculative decoding

This PR is a part of a larger series of PRs implementing speculative decoding, contributed to open source vLLM by Anyscale. See #2188 and Speculative decoding open sourcing plan for more information.

Rejection sampling

This PR implements optimized rejection sampling, including the following features:

  • Implementation of modified rejection sampling as described in https://arxiv.org/pdf/2302.01318.pdf
  • All operations are batched on GPU, allowing non-blocking computation.
  • Efficient collection of metrics regarding acceptance rate and number of emitted tokens.

It also contributes tests which verify the rejection sampler's ability to approximate distributions, given enough samples.

The following people contributed to it: @cadedaniel @Yard1 @amogkam

Details

The basic idea behind rejection sampling is that one can sample from the target distribution (larger model) using samples from a proposal distribution (smaller draft model), while guaranteeing the output distribution is equivalent to the target distribution.

"Modified" rejection sampling is introduced in the paper. It ensures that at least one token will always be emitted from the rejection sampling routine, even if all proposal tokens are rejected.

With LLMs, modified rejection sampling can reduce latency because multiple proposal sequences can be evaluated at once (batching on the GPU).

Finally, the paper introduces the notion of a "bonus" token. In the case where all proposed tokens are accepted, an additional token can be emitted. This is possible by having the target model predict the next token given the entire proposed sequence as context.

Visual confirmation that modified rejection sampling approximates the target distribution:
rejsample1
rejsample2

code for visualizations: https://gist.github.com/cadedaniel/07c1cd4ac003f51140b205580ac02613

@cadedaniel
Copy link
Collaborator Author

@cadedaniel
Copy link
Collaborator Author

The next PR will be cadedaniel#1, will create it once this is merged.


# Create masks using the indices.
indices = torch.arange(k, device=accepted.device).unsqueeze(0)
accepted_mask = indices < limits.unsqueeze(1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the difference between accepted and accepted_mask?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

accepted is the result of the rejection sampling condition. accepted_mask is True up until the first position rejected by the rejection sampling condition.

Example for k=3, bs=5:

>>> accepted
tensor([[ True, False,  True],
        [False, False, False],
        [ True,  True, False],
        [ True,  True,  True],
        [False,  True, False]])
>>> accepted_mask
tensor([[ True, False, False],
        [False, False, False],
        [ True,  True, False],
        [ True,  True,  True],
        [False, False, False]])

super().__init__()
self.probs_dtype = torch.float32
self.token_id_dtype = torch.int64
self._num_bonus_tokens = 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when can num_bonus_tokens > 1? Is it the last generated token by the target model iff all drafted tokens are accepted?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when can num_bonus_tokens > 1?

It is always 1. This variable is for readability only. I'll add a comment.

Is it the last generated token by the target model iff all drafted tokens are accepted?

Yep!

@zhaoyang-star
Copy link
Contributor

zhaoyang-star commented Jan 9, 2024

Very exciting work! I hope this feature can be merged soon as many other framworks such as TGI, TRT-LLM, llama.cpp, gpt-fast have supported Speculative sampling.

f = torch.clamp(difference, min=self._smallest_positive_value)

# shape [batch_size, k, vocab_size]
recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: torch.multinomial does not require the probability to be normalized.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will leave this in to keep the maths consistent with https://arxiv.org/pdf/2302.01318.pdf. This operation is not the compute or scheduling bottleneck.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants