Separate attention backends #3005

Merged: 44 commits merged into main from refactor-attn on Mar 7, 2024.
Conversation

WoosukKwon (Collaborator) commented on Feb 23, 2024:

This PR refactors the attention layer. Specifically, it separates the code paths for Ampere or more recent NVIDIA GPUs (which can directly use FlashAttention) and other GPUs, so that the code for the former becomes much simpler. This PR will also bring some performance improvements for ALiBi models, since we now directly call FlashAttention instead of using xformers in the middle.
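
For illustration, a minimal sketch of how such a capability-based split could be selected (the helper function below is hypothetical, not the PR's actual code; the backend class names are the ones mentioned later in this thread). Ampere and newer NVIDIA GPUs report compute capability 8.0 or higher.

import torch

def select_attention_backend() -> str:
    # Ampere and newer GPUs (compute capability >= 8.0) can call
    # FlashAttention directly; older GPUs fall back to the xformers path.
    if torch.cuda.is_available():
        major, _minor = torch.cuda.get_device_capability()
        if major >= 8:
            return "FlashAttentionBackend"
    return "XFormersBackend"

print(select_attention_backend())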

WoosukKwon (Collaborator, Author) commented:

@zhuohan123 What do you think about this design? Please note that while I used flash_attn for now, it will be replaced with FlashInfer.

zhuohan123 (Member) left a comment:

In general the refactor LGTM. My only small concern is the learning cost of AttentionFactory, since it does not completely behave like a torch nn.Module. I think this can add difficulty for people adding new models.
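
To illustrate the concern (with hypothetical signatures; this is not the PR's actual AttentionFactory code): a factory that returns a concrete backend at construction time is one more concept to learn than simply instantiating an nn.Module subclass, which is what model code normally does.

import torch.nn as nn
import torch.nn.functional as F

class XFormersAttention(nn.Module):
    # Stand-in for the xformers-based backend; a real backend would call a
    # different kernel here.
    def forward(self, query, key, value):
        return F.scaled_dot_product_attention(query, key, value)

class FlashAttention(nn.Module):
    # Stand-in for the flash_attn-based backend.
    def forward(self, query, key, value):
        return F.scaled_dot_product_attention(query, key, value)

def attention_factory(use_flash_attn: bool) -> nn.Module:
    # Reads like a layer constructor in model code, but actually dispatches
    # to one of several backend modules -- the extra indirection the review
    # comment is concerned about.
    return FlashAttention() if use_flash_attn else XFormersAttention()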

Resolved review threads on: vllm/model_executor/layers/attention/__init__.py, vllm/model_executor/layers/attention/non_flash.py, vllm/model_executor/layers/attention/paged_attn.py
            alibi_slopes=self.alibi_slopes,
        )
    else:
        # prefix-enabled attention
chenxu2048 (Contributor) commented:

The prefix-enabled attention and decoding parts are the same as those in non_flash.py. Could we move them into BaseAttention? Something like:

class BaseAttention(nn.Module):
    def forward(self, ...):
        if input_metadata.is_prompt:
            if ...:
                self._do_prompt_attention()
            else:
                # prefix-enabled attention
        else:
            # Decoding run.

    def _do_prompt_attention(self):
        # use xformers or flash_attn here

WoosukKwon (Collaborator, Author) replied:

Hi @chenxu2048, thanks for your input. I intentionally avoided this design since some attention implementations may not follow this structure. For example, an attention kernel may process the prompt attention and the prefix-enabled attention together. In terms of flexibility, I think the current structure is preferable.

chenxu2048 (Contributor) replied:

Thanks for your explanation.

WoosukKwon marked this pull request as ready for review on February 29, 2024 at 00:47.
WoosukKwon (Collaborator, Author) commented:

@zhuohan123 PTAL. Please note that I intentionally didn't make changes to models other than Llama.

zhuohan123 (Member) left a comment:

Thanks! In general LGTM. Will you change all other model files when you merge the PR?

Comment on lines +81 to +124
if input_metadata.is_prompt:
    # Prompt run.
    if (key_cache is None or value_cache is None
            or input_metadata.block_tables.numel() == 0):
        # normal attention
        query = query.unflatten(0, (batch_size, seq_len))
        key = key.unflatten(0, (batch_size, seq_len))
        value = value.unflatten(0, (batch_size, seq_len))
        output = flash_attn_func(
            query,
            key,
            value,
            softmax_scale=self.scale,
            causal=True,
            window_size=self.sliding_window,
            alibi_slopes=self.alibi_slopes,
        )
    else:
        # prefix-enabled attention
        output = PagedAttentionImpl.forward_prefix(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            self.num_heads,
            self.num_kv_heads,
            self.alibi_slopes,
        )
else:
    # Decoding run.
    output = PagedAttentionImpl.forward_decode(
        query,
        key_cache,
        value_cache,
        input_metadata,
        self.num_kv_heads,
        self.scale,
        self.alibi_slopes,
    )

# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
Yard1 (Collaborator) commented:

I would still suggest separating this out into private methods (_forward_decode, _forward_prefill, etc.) so that forward just decides which method to dispatch to.

WoosukKwon (Collaborator, Author) replied:

Thanks for your inputs! Actually, I intentionally avoided the design you proposed to preserve flexibility in implementing the attention backends. As you pointed out, an attention backend performs four tasks: 1) storing the input KV tensors into the KV cache, 2) computing prefills, 3) computing prefills with prefixes, and 4) computing decodes. Currently, the two attention backends (FlashAttentionBackend and XFormersBackend) have a kernel for each task. However, this may not necessarily be true in the future. For example, depending on the kernel implementation, one could compute prefills with and without prefixes (2 & 3) at the same time. As another example, an attention kernel in TRT-LLM stores the KV cache while computing decodes (1 & 4). Things can get even more complicated if we implement something like Cascade inference. Hence, I believe we shouldn't fix a certain structure for the attention backends.

@Yard1 What do you think about this?
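
For illustration, a minimal sketch of this point (class and method names are hypothetical and do not correspond to the PR's code): only the outer forward signature is fixed, and each backend decides internally whether to split or fuse the four tasks.

import torch

class FreeFormBackend:
    # Hypothetical backend: only forward() is part of the interface, so the
    # implementation is free to fuse cache writes, prefix handling, and
    # decoding into however many kernels it likes.
    def __init__(self, scale: float):
        self.scale = scale

    def forward(self, query, key, value, kv_cache, is_prompt: bool):
        if is_prompt:
            # A single fused kernel could handle prefills with and without
            # prefixes here; nothing forces two separate private methods.
            return self._attend(query, key, value)
        # Likewise, a decode kernel could append key/value to kv_cache
        # internally (as some TRT-LLM kernels do) rather than relying on a
        # separate cache-write step.
        cached_key, cached_value = kv_cache
        return self._attend(query, cached_key, cached_value)

    def _attend(self, q, k, v):
        # Plain scaled-dot-product attention as a stand-in for a real kernel.
        attn = torch.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)
        return attn @ v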

Yard1 (Collaborator) replied:

I agree we should not make them part of public API, but they can be done as private APIs for the backends that do have that distinction. Basically we should try to modularize the forward method if possible as it makes it easier to read and test.

WoosukKwon (Collaborator, Author) replied:

Got it. First, I believe the current implementation is easy to read; XFormersBackend is essentially the same as the current main branch and FlashAttentionBackend is simpler than that. Particularly for FlashAttentionBackend, I believe the implementation in this PR is very easy to understand.

That being said, I do agree that modularizing the backends will make it easy to test them. However, since this PR has already been delayed quite a bit, let's merge the PR and do modularization in the next PR.

Comment on lines +64 to +111
if use_v1:
    # Run PagedAttention V1.
    ops.paged_attention_v1(
        output,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        input_metadata.block_tables,
        input_metadata.context_lens,
        block_size,
        input_metadata.max_context_len,
        alibi_slopes,
        input_metadata.kv_cache_dtype,
    )
else:
    # Run PagedAttention V2.
    assert _PARTITION_SIZE % block_size == 0
    tmp_output = torch.empty(
        size=(num_seqs, num_heads, max_num_partitions, head_size),
        dtype=output.dtype,
        device=output.device,
    )
    exp_sums = torch.empty(
        size=(num_seqs, num_heads, max_num_partitions),
        dtype=torch.float32,
        device=output.device,
    )
    max_logits = torch.empty_like(exp_sums)
    ops.paged_attention_v2(
        output,
        exp_sums,
        max_logits,
        tmp_output,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        input_metadata.block_tables,
        input_metadata.context_lens,
        block_size,
        input_metadata.max_context_len,
        alibi_slopes,
        input_metadata.kv_cache_dtype,
    )
return output
Yard1 (Collaborator) commented:

Ditto as in the previous comment (_forward_decode_v1, _forward_decode_v2).

WoosukKwon (Collaborator, Author) replied:

Ditto. Let's do it in the next PR.

WoosukKwon merged commit 2daf23a into main on Mar 7, 2024 (20 of 23 checks passed).
WoosukKwon deleted the refactor-attn branch on March 7, 2024 at 10:03.
WoosukKwon added a commit that referenced this pull request on Mar 8, 2024.
grandiose-pizza pushed a commit to grandiose-pizza/vllm-jais that referenced this pull request on Mar 9, 2024.
        )
else:
    # Decoding run.
    output = PagedAttentionImpl.forward_decode(
A Contributor commented:

I am just curious: why not use flash_attn_with_kvcache? That kernel is faster than the paged_attention kernel. More benchmark details can be found in #2744.
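
For context, a minimal sketch of the call the comment refers to, assuming flash-attn's flash_attn_with_kvcache interface with a contiguous (non-paged) KV cache; it requires a CUDA GPU and fp16/bf16 tensors, and keyword arguments may differ across flash-attn versions.

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_seqlen = 2, 8, 128, 1024
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_new = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
v_new = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
# Number of tokens already stored in the cache for each sequence.
cache_seqlens = torch.tensor([17, 5], dtype=torch.int32, device="cuda")

# Appends k_new/v_new into the cache at position cache_seqlens and computes
# attention over the cached prefix plus the new token in one kernel.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    k=k_new, v=v_new,
    cache_seqlens=cache_seqlens,
    causal=True,
)
print(out.shape)  # (batch, 1, nheads, headdim)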
