Separate attention backends #3005

Merged · 44 commits · Mar 7, 2024

Commits (44, all changes shown):
a40b2c9
Add attention_backends
WoosukKwon Feb 23, 2024
6b6f7c7
Move
WoosukKwon Feb 23, 2024
f2b888c
Remove if
WoosukKwon Feb 23, 2024
7f4422c
Remove if
WoosukKwon Feb 23, 2024
1d9dc99
Attention
WoosukKwon Feb 23, 2024
534d0f8
Minor
WoosukKwon Feb 23, 2024
404022a
Minor
WoosukKwon Feb 23, 2024
194df2f
Rename
WoosukKwon Feb 23, 2024
a6910ea
Add flash-attn
WoosukKwon Feb 23, 2024
346b1b7
Address review
WoosukKwon Feb 27, 2024
05579fa
Merge branch 'main' into refactor-attn
WoosukKwon Feb 28, 2024
da115dd
Move
WoosukKwon Feb 29, 2024
19ecd4d
Move
WoosukKwon Feb 29, 2024
5b8e8c7
Minor
WoosukKwon Feb 29, 2024
ef8ace1
Rename
WoosukKwon Feb 29, 2024
6490fb4
Fix attention
WoosukKwon Feb 29, 2024
3baebac
Minor
WoosukKwon Feb 29, 2024
6a81692
Minor
WoosukKwon Feb 29, 2024
963a2c7
Minor
WoosukKwon Feb 29, 2024
38baed7
Add comment
WoosukKwon Feb 29, 2024
f97fc52
Merge branch 'main' into refactor-attn
WoosukKwon Mar 5, 2024
89069b8
Attention backends -> Attention
WoosukKwon Mar 5, 2024
9ba068a
Move
WoosukKwon Mar 5, 2024
677ad69
Move
WoosukKwon Mar 5, 2024
f5c7b07
Minor
WoosukKwon Mar 5, 2024
4a80dd0
Ops
WoosukKwon Mar 5, 2024
1319fc9
Minor
WoosukKwon Mar 5, 2024
281c5d5
Import Attention
WoosukKwon Mar 5, 2024
2f32381
Minor
WoosukKwon Mar 5, 2024
c68fe7e
Update models
WoosukKwon Mar 5, 2024
f65f65d
forward -> forward_decode
WoosukKwon Mar 5, 2024
45d02a1
Add packaging to requirements
WoosukKwon Mar 5, 2024
8e5ca7e
Merge branch 'main' into refactor-attn
WoosukKwon Mar 6, 2024
8333223
Remove FlashAttention from requirements
WoosukKwon Mar 6, 2024
ed1ab56
Add FlashInfer wheels to vLLM
WoosukKwon Mar 6, 2024
73aedbd
Minor
WoosukKwon Mar 6, 2024
0214afd
maybe fix packaging
WoosukKwon Mar 6, 2024
12ea60d
Add gitignore
WoosukKwon Mar 6, 2024
b460c21
Revert
WoosukKwon Mar 7, 2024
4ffa89f
revert
WoosukKwon Mar 7, 2024
6ba0e70
Copy after build
WoosukKwon Mar 7, 2024
974db99
Binary distribution
WoosukKwon Mar 7, 2024
f72560c
yapf
WoosukKwon Mar 7, 2024
0b8ac9e
Fix a bug for FP32
WoosukKwon Mar 7, 2024
3 changes: 3 additions & 0 deletions .gitignore
@@ -184,3 +184,6 @@ _build/

# Benchmark dataset
*.json

# Third-party Python packages.
vllm/thirdparty_files/
48 changes: 45 additions & 3 deletions setup.py
@@ -3,6 +3,7 @@
import os
import re
import subprocess
import sys
import warnings
from pathlib import Path
from typing import List, Set
@@ -14,6 +15,8 @@
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME

ROOT_DIR = os.path.dirname(__file__)
# This is a temporary directory to store third-party packages.
THIRDPARTY_SUBDIR = "vllm/thirdparty_files"

# If you are developing the C++ backend of vLLM, consider building vLLM with
# `python setup.py develop` since it will give you incremental builds.
@@ -324,8 +327,46 @@ def get_torch_arch_list() -> Set[str]:
"nvcc": NVCC_FLAGS_PUNICA,
},
))
elif _is_neuron():
neuronxcc_version = get_neuronxcc_version()

# Download the FlashAttention package.
    # Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/setup.py#L518-L530
flash_attn_version = "2.5.6"
install_dir = os.path.join(ROOT_DIR, THIRDPARTY_SUBDIR)
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"install",
"-q",
f"--target={install_dir}",
"einops", # Dependency of flash-attn.
f"flash-attn=={flash_attn_version}",
"--no-dependencies", # Required to avoid re-installing torch.
],
env=dict(os.environ, CC="gcc"),
)

# Copy the FlashAttention package into the vLLM package after build.
class build_ext(BuildExtension):

def run(self):
super().run()
target_dir = os.path.join(self.build_lib, THIRDPARTY_SUBDIR)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
self.copy_tree(install_dir, target_dir)

class BinaryDistribution(setuptools.Distribution):

def has_ext_modules(self):
return True

else:
build_ext = BuildExtension
BinaryDistribution = setuptools.Distribution
if _is_neuron():
neuronxcc_version = get_neuronxcc_version()

vllm_extension_sources = [
"csrc/cache_kernels.cu",
@@ -468,6 +509,7 @@ def get_requirements() -> List[str]:
python_requires=">=3.8",
install_requires=get_requirements(),
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {},
cmdclass={"build_ext": build_ext} if not _is_neuron() else {},
distclass=BinaryDistribution,
package_data=package_data,
)
2 changes: 1 addition & 1 deletion tests/kernels/test_prefix_prefill.py
@@ -3,7 +3,7 @@
import time

import torch
from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
from vllm.model_executor.layers.attention.ops.prefix_prefill import (
context_attention_fwd)
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
30 changes: 23 additions & 7 deletions vllm/__init__.py
@@ -1,12 +1,28 @@
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster
from vllm.entrypoints.llm import LLM
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams

# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/ray/__init__.py#L11
def _configure_system():
import os
import sys

# Importing flash-attn.
thirdparty_files = os.path.join(os.path.abspath(os.path.dirname(__file__)),
"thirdparty_files")
sys.path.insert(0, thirdparty_files)


_configure_system()
# Delete configuration function.
del _configure_system

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs # noqa: E402
from vllm.engine.async_llm_engine import AsyncLLMEngine # noqa: E402
from vllm.engine.llm_engine import LLMEngine # noqa: E402
from vllm.engine.ray_utils import initialize_cluster # noqa: E402
from vllm.entrypoints.llm import LLM # noqa: E402
from vllm.outputs import CompletionOutput, RequestOutput # noqa: E402
from vllm.sampling_params import SamplingParams # noqa: E402

__version__ = "0.3.3"

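For reference, a small check along the following lines could confirm how this `__init__.py` change and the `setup.py` change above are meant to fit together. This is an assumption about the intended behavior of an installed CUDA build, not something stated in the PR; the expected version and file location come from `setup.py`'s pin and the `thirdparty_files` path.

```python
# Sanity-check sketch (assumption, not part of the PR): after installing this
# branch, importing vllm should make the vendored flash-attn importable,
# because vllm/__init__.py prepends vllm/thirdparty_files/ to sys.path.
import vllm  # noqa: F401  (runs _configure_system() on import)
import flash_attn

print(flash_attn.__version__)  # setup.py pins flash-attn 2.5.6
print(flash_attn.__file__)     # expected to resolve under .../vllm/thirdparty_files/
```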
5 changes: 5 additions & 0 deletions vllm/model_executor/layers/attention/__init__.py
@@ -0,0 +1,5 @@
from vllm.model_executor.layers.attention.attention import Attention

__all__ = [
"Attention",
]
59 changes: 59 additions & 0 deletions vllm/model_executor/layers/attention/attention.py
@@ -0,0 +1,59 @@
"""Attention layer."""
from typing import List, Optional

import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip


class Attention(nn.Module):
"""Attention layer.

This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:

1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
"""

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and
torch.get_default_dtype() in (torch.float16, torch.bfloat16)):
# Ampere or later NVIDIA GPUs.
# NOTE(woosuk): FlashAttention does not support FP32.
from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend
self.backend = FlashAttentionBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
sliding_window)
else:
# Turing and Volta NVIDIA GPUs or AMD GPUs.
# Or FP32 on any GPU.
from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend
self.backend = XFormersBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
sliding_window)

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
) -> torch.Tensor:
return self.backend.forward(query, key, value, key_cache, value_cache,
input_metadata)
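For reference, a minimal sketch of how a model's attention module would wire up the new unified layer (this is what the "Import Attention" / "Update models" commits do across the model files). `ToyAttentionBlock` is a hypothetical name for illustration; in the real models the q/k/v projections, KV caches, and `InputMetadata` come from the model runner.

```python
import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import Attention


class ToyAttentionBlock(nn.Module):
    """Hypothetical wrapper showing how model code would use the new layer."""

    def __init__(self, num_heads: int = 32, head_size: int = 128) -> None:
        super().__init__()
        # Backend selection (FlashAttention vs. xFormers) happens inside
        # Attention, based on GPU capability and the default dtype.
        self.attn = Attention(num_heads=num_heads,
                              head_size=head_size,
                              scale=head_size**-0.5)

    def forward(
        self,
        q: torch.Tensor,  # [batch_size, seq_len, num_heads * head_size]
        k: torch.Tensor,  # [batch_size, seq_len, num_kv_heads * head_size]
        v: torch.Tensor,  # [batch_size, seq_len, num_kv_heads * head_size]
        key_cache: torch.Tensor,    # may be None during memory profiling
        value_cache: torch.Tensor,  # may be None during memory profiling
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        return self.attn(q, k, v, key_cache, value_cache, input_metadata)
```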
124 changes: 124 additions & 0 deletions vllm/model_executor/layers/attention/backends/flash_attn.py
@@ -0,0 +1,124 @@
"""Attention layer with Flash and PagedAttention."""
from typing import List, Optional

# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/.
from flash_attn import flash_attn_func
import torch

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention.ops.paged_attn import (
PagedAttentionImpl)


class FlashAttentionBackend:

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")

self.sliding_window = ((self.sliding_window, self.sliding_window) if
self.sliding_window is not None else (-1, -1))

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.

Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
batch_size, seq_len, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

# Reshape the keys and values and store them in the cache.
# If key_cache and value_cache are not provided, the new key and value
# vectors will not be cached. This happens during the initial memory
# profiling run.
if key_cache is not None and value_cache is not None:
PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
value_cache, input_metadata)

if input_metadata.is_prompt:
# Prompt run.
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
# normal attention
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))
output = flash_attn_func(
query,
key,
value,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
# prefix-enabled attention
output = PagedAttentionImpl.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
input_metadata,
self.num_heads,
self.num_kv_heads,
self.alibi_slopes,
)
else:
# Decoding run.
            output = PagedAttentionImpl.forward_decode(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)

Review comment (Contributor), on the PagedAttentionImpl.forward_decode call above:

I am just curious why you do not use flash_attn_with_kvcache. The kernel is faster than paged_attention_kernel. More benchmark details can be found in #2744.
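For context on that question, here is a hedged sketch of what a decode step with `flash_attn_with_kvcache` could look like. It is not how this PR works: the shapes follow flash-attn's own paged-KV convention, not vLLM's `[num_blocks, num_kv_heads, head_size/x, block_size, x]` cache layout, and it assumes flash-attn >= 2.5 with paged-KV support (exact page-size constraints depend on the flash-attn version). All tensor sizes below are dummy values for illustration.

```python
# Sketch only: a single decode step via flash_attn_with_kvcache with a paged
# KV cache, using flash-attn's layout rather than vLLM's.
import torch
from flash_attn import flash_attn_with_kvcache

batch_size, num_heads, num_kv_heads, head_size = 4, 32, 32, 128
num_blocks, block_size, max_blocks_per_seq = 64, 256, 4

# One new query token per sequence (decode step).
q = torch.randn(batch_size, 1, num_heads, head_size,
                dtype=torch.float16, device="cuda")
# flash-attn's paged KV layout: [num_blocks, block_size, num_kv_heads, head_size].
k_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size,
                      dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)
# Per-sequence block tables and current context lengths (dummy values).
block_table = torch.randint(0, num_blocks, (batch_size, max_blocks_per_seq),
                            dtype=torch.int32, device="cuda")
cache_seqlens = torch.full((batch_size,), 100, dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    block_table=block_table,
    softmax_scale=head_size**-0.5,
    causal=True,
)  # [batch_size, 1, num_heads, head_size]
```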

Review thread on lines +81 to +124:

Collaborator:

I would still suggest separating this out into private methods (_forward_decode, _forward_prefill, etc.) so that forward can just decide which method to dispatch to.


Collaborator Author (WoosukKwon):

Thanks for your input! Actually, I intentionally avoided the design you proposed in order to preserve flexibility in how attention backends are implemented. As you pointed out, an attention backend performs four tasks: 1) storing the input KV tensors in the KV cache, 2) computing prefills, 3) computing prefills with prefixes, and 4) computing decodes. Currently, the two attention backends (FlashAttentionBackend and XFormersBackend) use one kernel per task. However, this may not necessarily be true in the future. For example, depending on the kernel implementation, one could compute prefills with and without prefixes (2 & 3) at the same time. As another example, an attention kernel in TRT-LLM stores the KV cache while computing decodes (1 & 4). Things can get even more complicated if we implement something like Cascade inference. Hence, I believe we shouldn't fix a particular structure for the attention backends.

@Yard1 What do you think about this?


Collaborator:

I agree we should not make them part of the public API, but they can be done as private APIs for the backends that do have that distinction. Basically, we should try to modularize the forward method if possible, as it makes it easier to read and test.


Collaborator Author (WoosukKwon):

Got it. First, I believe the current implementation is easy to read: XFormersBackend is essentially the same as the current main branch, and FlashAttentionBackend is simpler than that. In particular, I think the FlashAttentionBackend implementation in this PR is very easy to understand.

That being said, I do agree that modularizing the backends will make them easier to test. However, since this PR has already been delayed quite a bit, let's merge it and do the modularization in a follow-up PR.
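To make the suggested follow-up concrete, here is a rough sketch of the private-method structure being discussed. The method names come from the review comment; the class name, helper split, and stubbed bodies are assumptions for illustration, not code from this PR.

```python
# Sketch only: how a backend's forward() could dispatch to private helpers,
# as proposed in the review thread above. Bodies are stubbed; in a real
# refactor the existing kernel calls would move into them unchanged.
class FlashAttentionBackendSketch:

    def forward(self, query, key, value, key_cache, value_cache,
                input_metadata):
        if key_cache is not None and value_cache is not None:
            self._write_kv_cache(key, value, key_cache, value_cache,
                                 input_metadata)
        if input_metadata.is_prompt:
            return self._forward_prefill(query, key, value, key_cache,
                                         value_cache, input_metadata)
        return self._forward_decode(query, key_cache, value_cache,
                                    input_metadata)

    def _write_kv_cache(self, key, value, key_cache, value_cache,
                        input_metadata):
        raise NotImplementedError  # e.g. PagedAttentionImpl.reshape_and_cache

    def _forward_prefill(self, query, key, value, key_cache, value_cache,
                         input_metadata):
        raise NotImplementedError  # e.g. flash_attn_func / forward_prefix

    def _forward_decode(self, query, key_cache, value_cache, input_metadata):
        raise NotImplementedError  # e.g. PagedAttentionImpl.forward_decode
```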
