Separate attention backends #3005

Merged: 44 commits, Mar 7, 2024
Changes from 20 commits

Commits
a40b2c9
Add attention_backends
WoosukKwon Feb 23, 2024
6b6f7c7
Move
WoosukKwon Feb 23, 2024
f2b888c
Remove if
WoosukKwon Feb 23, 2024
7f4422c
Remove if
WoosukKwon Feb 23, 2024
1d9dc99
Attention
WoosukKwon Feb 23, 2024
534d0f8
Minor
WoosukKwon Feb 23, 2024
404022a
Minor
WoosukKwon Feb 23, 2024
194df2f
Rename
WoosukKwon Feb 23, 2024
a6910ea
Add flash-attn
WoosukKwon Feb 23, 2024
346b1b7
Address review
WoosukKwon Feb 27, 2024
05579fa
Merge branch 'main' into refactor-attn
WoosukKwon Feb 28, 2024
da115dd
Move
WoosukKwon Feb 29, 2024
19ecd4d
Move
WoosukKwon Feb 29, 2024
5b8e8c7
Minor
WoosukKwon Feb 29, 2024
ef8ace1
Rename
WoosukKwon Feb 29, 2024
6490fb4
Fix attention
WoosukKwon Feb 29, 2024
3baebac
Minor
WoosukKwon Feb 29, 2024
6a81692
Minor
WoosukKwon Feb 29, 2024
963a2c7
Minor
WoosukKwon Feb 29, 2024
38baed7
Add comment
WoosukKwon Feb 29, 2024
f97fc52
Merge branch 'main' into refactor-attn
WoosukKwon Mar 5, 2024
89069b8
Attention backends -> Attention
WoosukKwon Mar 5, 2024
9ba068a
Move
WoosukKwon Mar 5, 2024
677ad69
Move
WoosukKwon Mar 5, 2024
f5c7b07
Minor
WoosukKwon Mar 5, 2024
4a80dd0
Ops
WoosukKwon Mar 5, 2024
1319fc9
Minor
WoosukKwon Mar 5, 2024
281c5d5
Import Attention
WoosukKwon Mar 5, 2024
2f32381
Minor
WoosukKwon Mar 5, 2024
c68fe7e
Update models
WoosukKwon Mar 5, 2024
f65f65d
forward -> forward_decode
WoosukKwon Mar 5, 2024
45d02a1
Add packaging to requirements
WoosukKwon Mar 5, 2024
8e5ca7e
Merge branch 'main' into refactor-attn
WoosukKwon Mar 6, 2024
8333223
Remove FlashAttention from requirements
WoosukKwon Mar 6, 2024
ed1ab56
Add FlashInfer wheels to vLLM
WoosukKwon Mar 6, 2024
73aedbd
Minor
WoosukKwon Mar 6, 2024
0214afd
maybe fix packaging
WoosukKwon Mar 6, 2024
12ea60d
Add gitignore
WoosukKwon Mar 6, 2024
b460c21
Revert
WoosukKwon Mar 7, 2024
4ffa89f
revert
WoosukKwon Mar 7, 2024
6ba0e70
Copy after build
WoosukKwon Mar 7, 2024
974db99
Binary distribution
WoosukKwon Mar 7, 2024
f72560c
yapf
WoosukKwon Mar 7, 2024
0b8ac9e
Fix a bug for FP32
WoosukKwon Mar 7, 2024
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,6 +5,7 @@ sentencepiece # Required for LLaMA tokenizer.
numpy
torch == 2.1.2
transformers >= 4.38.0 # Required for Gemma.
+flash_attn == 2.5.5
xformers == 0.0.23.post1 # Required for CUDA 12.1.
fastapi
uvicorn[standard]
2 changes: 1 addition & 1 deletion tests/kernels/test_prefix_prefill.py
@@ -3,7 +3,7 @@
import time

import torch
-from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
+from vllm.model_executor.layers.attention_backends.prefix_prefill import (
context_attention_fwd)
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
331 changes: 19 additions & 312 deletions vllm/model_executor/layers/attention.py
@@ -1,35 +1,22 @@
"""Multi-head attention."""
"""Attention layer."""
from typing import List, Optional

import importlib
import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias)

from vllm._C import ops
from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
context_attention_fwd)
from vllm.utils import is_hip

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


class PagedAttention(nn.Module):
"""MHA/MQA/GQA layer with PagedAttention.
class Attention(nn.Module):
"""Attention layer.

This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:

1. Reshape and store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention using either
xformers or the PagedAttention custom op.
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
"""

@@ -43,55 +30,18 @@ def __init__(
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

if self.head_size not in _SUPPORTED_HEAD_SIZES:
raise ValueError(f"head_size ({self.head_size}) is not supported. "
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

self.use_ref_attention = self.check_use_ref_attention()

def check_use_ref_attention(self) -> bool:
if not is_hip():
return False
# For ROCm, check whether flash attention is installed or not.
# if not, use_ref_attention needs to be True
return importlib.util.find_spec("flash_attn") is None
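
The removed fallback check above probes for the flash_attn package at runtime. Restated as a standalone snippet (the helper name `has_flash_attn` is mine, not part of this PR):

```python
import importlib.util


def has_flash_attn() -> bool:
    # find_spec returns None when the package cannot be imported, so this is a
    # cheap way to test for flash-attn without importing it.
    return importlib.util.find_spec("flash_attn") is not None
```

With this refactor, per-platform decisions like this move out of the layer and into the backend modules.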

def ref_masked_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

seq_len, _, _ = query.shape
attn_mask = torch.triu(torch.ones(seq_len,
seq_len,
dtype=query.dtype,
device=query.device),
diagonal=1)
attn_mask = attn_mask * torch.finfo(query.dtype).min

attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query,
key).float()
attn_weights = attn_weights + attn_mask.float()
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
return out
if not is_hip() and torch.cuda.get_device_capability()[0] >= 8:
# Ampere or later NVIDIA GPUs.
from vllm.model_executor.layers.attention_backends.flash_attn import FlashAttentionBackend
self.backend = FlashAttentionBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
sliding_window)
else:
# Turing and Volta NVIDIA GPUs or AMD GPUs.
from vllm.model_executor.layers.attention_backends.xformers import XFormersBackend
self.backend = XFormersBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
sliding_window)
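
Both branches above construct a backend with the same arguments, and the new `forward` body at the bottom of this diff simply delegates to it. A minimal sketch of that implicit, duck-typed contract (the `Protocol` below is illustrative only; the PR does not define such a class here):

```python
from typing import Any, Optional, Protocol

import torch


class AttentionBackendLike(Protocol):
    """Contract implied by how Attention builds and calls its backend.

    Both FlashAttentionBackend and XFormersBackend are constructed as
    backend_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes,
    sliding_window) and are then only used through this method.
    """

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: Any,  # vllm.model_executor.input_metadata.InputMetadata
    ) -> torch.Tensor:
        ...
```

Model code only sees `Attention`, so additional backends can be added later without touching the models.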

def forward(
self,
@@ -102,248 +52,5 @@ def forward(
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
) -> torch.Tensor:
"""PagedAttention forward pass.

Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
batch_size, seq_len, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

# Reshape the keys and values and store them in the cache.
# If key_cache and value_cache are not provided, the new key and value
# vectors will not be cached. This happens during the initial memory
# profiling run.
if key_cache is not None and value_cache is not None:
cache_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
input_metadata.slot_mapping.flatten(),
input_metadata.kv_cache_dtype,
)
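
For a concrete sense of the cache layout documented in the docstring above, here are the resulting tensor shapes for one made-up configuration, where `x` is the per-16-byte packing factor used by the CUDA kernels (16 bytes divided by the element size). All numbers below are illustrative:

```python
import torch

# Made-up example configuration.
num_blocks, block_size = 1024, 16
num_kv_heads, head_size = 8, 128
dtype = torch.float16
x = 16 // dtype.itemsize  # 8 fp16 elements per 16-byte group

key_cache_shape = (num_blocks, num_kv_heads, head_size // x, block_size, x)
value_cache_shape = (num_blocks, num_kv_heads, head_size, block_size)

assert key_cache_shape == (1024, 8, 16, 16, 8)
assert value_cache_shape == (1024, 8, 128, 16)
```

reshape_and_cache scatters each token's key and value into these tensors at the positions given by slot_mapping.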

if input_metadata.is_prompt:
# normal attention
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])

# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
[seq_len] * batch_size)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = attn_bias
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)

if self.use_ref_attention:
output = self.ref_masked_attention(
query,
key,
value,
)
# Using view got RuntimeError: view size is not compatible with input tensor's size and stride
# (at least one dimension spans across two contiguous subspaces). Use reshape instead
return output.reshape(batch_size, seq_len, hidden_size)

# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
else:
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))

out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = out.view_as(query)
else:
# prefix-enabled attention
output = torch.empty_like(query)
context_attention_fwd(
query,
key,
value,
output,
key_cache,
value_cache,
input_metadata.block_tables, # [BS, max_block_per_request]
input_metadata.start_loc,
input_metadata.prompt_lens,
input_metadata.context_lens,
input_metadata.max_seq_len,
getattr(self, "alibi_slopes", None),
)

else:
# Decoding run.
output = _paged_attention(
query,
key_cache,
value_cache,
input_metadata,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
)

# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
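
The prompt path above handles MQA/GQA by viewing the query heads in groups and broadcasting each KV head across its group, since xformers expects matching head counts. A tiny standalone illustration of just that reshaping (toy sizes, not from the PR):

```python
import torch

# Toy sizes: 4 query heads sharing 2 KV heads (grouped-query attention),
# 3 tokens, head_size 8.
num_tokens, num_heads, num_kv_heads, head_size = 3, 4, 2, 8
num_queries_per_kv = num_heads // num_kv_heads  # 2

query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_kv_heads, head_size)

# Group the query heads by the KV head they share...
query = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
# ...and broadcast each KV head across its query group. expand() only creates
# a view, so no data is copied.
key = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                num_queries_per_kv, head_size)

assert query.shape == key.shape == (3, 2, 2, 8)
```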


def _make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
batch_size: int,
seq_len: int,
dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]

# When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8.
padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
batch_size,
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
attn_bias = LowerTriangularMaskWithTensorBias(bias)
return attn_bias
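
To make the bias that `_make_alibi_bias` produces concrete, here is the same computation for a toy case: seq_len 4, batch_size 1, and two heads with made-up slopes. The padding to a multiple of 8 mirrors the xformers requirement noted in the comments above.

```python
import torch

seq_len = 4
alibi_slopes = torch.tensor([0.5, 0.25])  # made-up slopes for two heads

# Relative positions: bias[i, j] = j - i, as in _make_alibi_bias above.
pos = torch.arange(seq_len, dtype=torch.float32)
bias = pos[None, :] - pos[:, None]
# tensor([[ 0.,  1.,  2.,  3.],
#         [-1.,  0.,  1.,  2.],
#         [-2., -1.,  0.,  1.],
#         [-3., -2., -1.,  0.]])

# xformers wants the bias sliced from a tensor whose last dim is a multiple of 8.
padded_len = (seq_len + 7) // 8 * 8  # -> 8

# One bias matrix per head, scaled by that head's slope (batch_size = 1 here).
num_heads = alibi_slopes.shape[0]
padded = torch.empty(1, num_heads, seq_len, padded_len)
per_head_bias = padded[:, :, :, :seq_len].copy_(bias)
per_head_bias.mul_(alibi_slopes[:, None, None])

# The LowerTriangularMaskWithTensorBias wrapper adds the causal mask, so the
# positive entries above the diagonal never survive the softmax.
assert per_head_bias.shape == (1, 2, 4, 4)
```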


def _paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
output = torch.empty_like(query)

block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (
(input_metadata.max_context_len + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = input_metadata.max_context_len <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1:
# Run PagedAttention V1.
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
alibi_slopes,
input_metadata.kv_cache_dtype,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
alibi_slopes,
input_metadata.kv_cache_dtype,
)
return output
return self.backend.forward(query, key, value, key_cache, value_cache,
input_metadata)
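
The decode helper removed above (its logic now lives with the backends) picks between the paged_attention_v1 and v2 kernels with a small heuristic. Restated as a standalone function with a few worked inputs (the function name and the example numbers are mine; the rule itself is the one in the removed code):

```python
# Restates the use_v1 heuristic from the removed _paged_attention helper above.
_PARTITION_SIZE = 512  # must match PARTITION_SIZE in paged_attention_v2_launcher


def use_paged_attention_v1(max_context_len: int, num_seqs: int,
                           num_heads: int) -> bool:
    """V1 skips V2's reduction step; prefer it when there is only one
    partition or when num_seqs * num_heads already provides enough parallel
    work. Beyond 8192 context tokens, always fall back to V2 to avoid the
    shared-memory shortage."""
    max_num_partitions = (max_context_len + _PARTITION_SIZE -
                          1) // _PARTITION_SIZE
    return max_context_len <= 8192 and (max_num_partitions == 1
                                        or num_seqs * num_heads > 512)


# Short context: a single partition, so V1.
assert use_paged_attention_v1(max_context_len=300, num_seqs=4, num_heads=32)
# Long context (> 8192 tokens): always V2, regardless of batch size.
assert not use_paged_attention_v1(max_context_len=9000, num_seqs=64, num_heads=32)
# Mid-sized context but a large batch: enough parallel work for V1.
assert use_paged_attention_v1(max_context_len=2048, num_seqs=32, num_heads=32)
```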