-
-
Notifications
You must be signed in to change notification settings - Fork 4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Misc] Add attention sinks #3515
Draft
felixzhu555
wants to merge
90
commits into
vllm-project:main
Choose a base branch
from
felixzhu555:add_attention_sinks
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Hi, @felixzhu555 . it is https://arxiv.org/abs/2309.17453 right? |
Yep, trying to implement the logic from that paper. Their repo is https://github.com/mit-han-lab/streaming-llm. |
We need to @rlouf to the PR the guy in charge of outline, it seems that your PR is failing on the guided part. |
simon-mo
reviewed
Apr 25, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Overview
This PR adds experimental support for attention sinks (#1304), based on this paper and repo. Support is currently limited to RoPE and ALiBi models (e.g. Llama, Mistral/Mixtral, Falcon, Bloom, MPT). The attention sink is hard-coded as the first block of tokens in a sequence.
Usage
Set
use_attention_sinks=True
when instantiatingLLM
orLLMEngine
, or set the--use-attention-sinks
CLI argument. Also setenforce_eager=True
(attention sinks currently does not work with CUDA graphs), and ensure the attention backend being used is FlashAttention, XFormers, or FlashInfer (WIP).Background
Experiments show that the attention mechanism heavily attends to the first few tokens of the sequence being completed, regardless of what the tokens are. Once sequence length exceeds the context length of a model, and we start evicting tokens from the beginning of the KV cache (in a sliding window fashion), the model will generate garbage (high perplexity).
This is where attention sinks come in. By always preserving the KVs for the first few tokens of the sequence while using a sliding window approach for the rest of the KV cache, the model can continue to generate sensible output (low perplexity). Theoretically, the model can stream indefinitely, as long as cache eviction is handled properly. Note the sliding window length is the model's context length.
Example
Suppose our model's context length is 2048, which equals 128 blocks of 16 tokens. Let's pass in a prompt of 2000 tokens. For the next 48 generated tokens, nothing changes; we end up filling 128 blocks so far.
Normally, vLLM forces generation to stop here since the model's context length has been reached. However, using attention sinks we bypass this stopping condition and keep generating.
At the next decode, we are writing the 2049th token to the cache and computing the 2050th token (1-based indexing). Here, we edit the block table to be
[block_table[0]] + block_table[2:]
, where we effectively ignore the 2nd block while retaining the 1st block, which is our attention sink. Notice how the block table is still length 128 because the 129th block was just allocated for token 2049. This modified block table is then used in the attention kernel.Every 16th decode that follows will ignore an additional block, but always retain the 1st block as the sink.
Modifications
This PR adds a
StreamingAttentionSink
layer that computes attention using modified block tables with the "sink" block concatenated with the remaining sliding window blocks. In the RoPE case, we always store pre-rope keys into the cache, and extra work must be done at every decode to rotate all keys for a sequence based on their new positions in the cache. Note: due to this extra work, using attention sinks incurs a significant drop in tokens/s for RoPE models (around 50-70% for Llama).use_attention_sinks
is now an argument toLLMEngine
, which passes it to the model runner and injects attention sinks into the model's modules. On every forward call of the model's attention layer, normal attention logic is replaced byStreamingAttentionSink
logic.The scheduler evicts (frees) a block (the "ignored" block) whenever a new block is allocated past the model's context length, such that the total number of used blocks is capped at
max_model_len // block_size
.Future Work
StreamingAttentionSink
assumes only 1 token is generated every decode.StreamingAttentionSink
directly edits the block table for every decode (past the context length), so the hash table for prefix caching cannot be used currently.