From 7184f4b08d4a42918bae0d91fb8755e5b844e1c3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 7 Mar 2024 01:45:50 -0800 Subject: [PATCH] Separate attention backends (#3005) --- .gitignore | 3 + setup.py | 48 +++- tests/kernels/test_prefix_prefill.py | 2 +- vllm/__init__.py | 30 ++- .../layers/attention/__init__.py | 5 + .../layers/attention/attention.py | 59 +++++ .../backends}/__init__.py | 0 .../layers/attention/backends/flash_attn.py | 124 ++++++++++ .../backends/xformers.py} | 216 +++++------------- .../layers/attention/ops/__init__.py | 0 .../layers/attention/ops/paged_attn.py | 138 +++++++++++ .../ops}/prefix_prefill.py | 0 vllm/model_executor/models/baichuan.py | 13 +- vllm/model_executor/models/bloom.py | 10 +- vllm/model_executor/models/chatglm.py | 4 +- vllm/model_executor/models/deepseek.py | 10 +- vllm/model_executor/models/falcon.py | 28 +-- vllm/model_executor/models/gemma.py | 10 +- vllm/model_executor/models/gpt2.py | 6 +- vllm/model_executor/models/gpt_bigcode.py | 10 +- vllm/model_executor/models/gpt_j.py | 4 +- vllm/model_executor/models/gpt_neox.py | 4 +- vllm/model_executor/models/internlm2.py | 10 +- vllm/model_executor/models/llama.py | 12 +- vllm/model_executor/models/mixtral.py | 4 +- vllm/model_executor/models/mixtral_quant.py | 4 +- vllm/model_executor/models/mpt.py | 12 +- vllm/model_executor/models/olmo.py | 8 +- vllm/model_executor/models/opt.py | 8 +- vllm/model_executor/models/orion.py | 10 +- vllm/model_executor/models/phi.py | 4 +- vllm/model_executor/models/qwen.py | 4 +- vllm/model_executor/models/qwen2.py | 12 +- vllm/model_executor/models/stablelm.py | 10 +- vllm/model_executor/models/starcoder2.py | 4 +- 35 files changed, 558 insertions(+), 268 deletions(-) create mode 100644 vllm/model_executor/layers/attention/__init__.py create mode 100644 vllm/model_executor/layers/attention/attention.py rename vllm/model_executor/layers/{triton_kernel => attention/backends}/__init__.py (100%) create mode 100644 vllm/model_executor/layers/attention/backends/flash_attn.py rename vllm/model_executor/layers/{attention.py => attention/backends/xformers.py} (56%) create mode 100644 vllm/model_executor/layers/attention/ops/__init__.py create mode 100644 vllm/model_executor/layers/attention/ops/paged_attn.py rename vllm/model_executor/layers/{triton_kernel => attention/ops}/prefix_prefill.py (100%) diff --git a/.gitignore b/.gitignore index b5195629e5cf3..0b14c98270c41 100644 --- a/.gitignore +++ b/.gitignore @@ -184,3 +184,6 @@ _build/ # Benchmark dataset *.json + +# Third-party Python packages. +vllm/thirdparty_files/ diff --git a/setup.py b/setup.py index 745b5a9b2d02a..57d7a139e8237 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,7 @@ import os import re import subprocess +import sys import warnings from pathlib import Path from typing import List, Set @@ -14,6 +15,8 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME ROOT_DIR = os.path.dirname(__file__) +# This is a temporary directory to store third-party packages. +THIRDPARTY_SUBDIR = "vllm/thirdparty_files" # If you are developing the C++ backend of vLLM, consider building vLLM with # `python setup.py develop` since it will give you incremental builds. @@ -324,8 +327,46 @@ def get_torch_arch_list() -> Set[str]: "nvcc": NVCC_FLAGS_PUNICA, }, )) -elif _is_neuron(): - neuronxcc_version = get_neuronxcc_version() + + # Download the FlashAttention package. + # Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/setup.py#L518-L530 + flash_attn_version = "2.5.6" + install_dir = os.path.join(ROOT_DIR, THIRDPARTY_SUBDIR) + subprocess.check_call( + [ + sys.executable, + "-m", + "pip", + "install", + "-q", + f"--target={install_dir}", + "einops", # Dependency of flash-attn. + f"flash-attn=={flash_attn_version}", + "--no-dependencies", # Required to avoid re-installing torch. + ], + env=dict(os.environ, CC="gcc"), + ) + + # Copy the FlashAttention package into the vLLM package after build. + class build_ext(BuildExtension): + + def run(self): + super().run() + target_dir = os.path.join(self.build_lib, THIRDPARTY_SUBDIR) + if not os.path.exists(target_dir): + os.makedirs(target_dir) + self.copy_tree(install_dir, target_dir) + + class BinaryDistribution(setuptools.Distribution): + + def has_ext_modules(self): + return True + +else: + build_ext = BuildExtension + BinaryDistribution = setuptools.Distribution + if _is_neuron(): + neuronxcc_version = get_neuronxcc_version() vllm_extension_sources = [ "csrc/cache_kernels.cu", @@ -468,6 +509,7 @@ def get_requirements() -> List[str]: python_requires=">=3.8", install_requires=get_requirements(), ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {}, + cmdclass={"build_ext": build_ext} if not _is_neuron() else {}, + distclass=BinaryDistribution, package_data=package_data, ) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index c068b38a66910..e881cd1ec3753 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -3,7 +3,7 @@ import time import torch -from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( +from vllm.model_executor.layers.attention.ops.prefix_prefill import ( context_attention_fwd) from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask diff --git a/vllm/__init__.py b/vllm/__init__.py index f1e30f5eb6e6e..59f1345b58d42 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -1,12 +1,28 @@ """vLLM: a high-throughput and memory-efficient inference engine for LLMs""" -from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.engine.llm_engine import LLMEngine -from vllm.engine.ray_utils import initialize_cluster -from vllm.entrypoints.llm import LLM -from vllm.outputs import CompletionOutput, RequestOutput -from vllm.sampling_params import SamplingParams + +# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/ray/__init__.py#L11 +def _configure_system(): + import os + import sys + + # Importing flash-attn. + thirdparty_files = os.path.join(os.path.abspath(os.path.dirname(__file__)), + "thirdparty_files") + sys.path.insert(0, thirdparty_files) + + +_configure_system() +# Delete configuration function. +del _configure_system + +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs # noqa: E402 +from vllm.engine.async_llm_engine import AsyncLLMEngine # noqa: E402 +from vllm.engine.llm_engine import LLMEngine # noqa: E402 +from vllm.engine.ray_utils import initialize_cluster # noqa: E402 +from vllm.entrypoints.llm import LLM # noqa: E402 +from vllm.outputs import CompletionOutput, RequestOutput # noqa: E402 +from vllm.sampling_params import SamplingParams # noqa: E402 __version__ = "0.3.3" diff --git a/vllm/model_executor/layers/attention/__init__.py b/vllm/model_executor/layers/attention/__init__.py new file mode 100644 index 0000000000000..1c42a3d28f976 --- /dev/null +++ b/vllm/model_executor/layers/attention/__init__.py @@ -0,0 +1,5 @@ +from vllm.model_executor.layers.attention.attention import Attention + +__all__ = [ + "Attention", +] diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py new file mode 100644 index 0000000000000..830e82e10f7ad --- /dev/null +++ b/vllm/model_executor/layers/attention/attention.py @@ -0,0 +1,59 @@ +"""Attention layer.""" +from typing import List, Optional + +import torch +import torch.nn as nn + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.utils import is_hip + + +class Attention(nn.Module): + """Attention layer. + + This class takes query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens. + The class does the following: + + 1. Store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention. + 3. Return the output tensor. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + super().__init__() + if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and + torch.get_default_dtype() in (torch.float16, torch.bfloat16)): + # Ampere or later NVIDIA GPUs. + # NOTE(woosuk): FlashAttention does not support FP32. + from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend + self.backend = FlashAttentionBackend(num_heads, head_size, scale, + num_kv_heads, alibi_slopes, + sliding_window) + else: + # Turing and Volta NVIDIA GPUs or AMD GPUs. + # Or FP32 on any GPU. + from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend + self.backend = XFormersBackend(num_heads, head_size, scale, + num_kv_heads, alibi_slopes, + sliding_window) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + input_metadata: InputMetadata, + ) -> torch.Tensor: + return self.backend.forward(query, key, value, key_cache, value_cache, + input_metadata) diff --git a/vllm/model_executor/layers/triton_kernel/__init__.py b/vllm/model_executor/layers/attention/backends/__init__.py similarity index 100% rename from vllm/model_executor/layers/triton_kernel/__init__.py rename to vllm/model_executor/layers/attention/backends/__init__.py diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py new file mode 100644 index 0000000000000..512f4e49c7eb2 --- /dev/null +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -0,0 +1,124 @@ +"""Attention layer with Flash and PagedAttention.""" +from typing import List, Optional + +# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/. +from flash_attn import flash_attn_func +import torch + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.attention.ops.paged_attn import ( + PagedAttentionImpl) + + +class FlashAttentionBackend: + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = sliding_window + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + self.sliding_window = ((self.sliding_window, self.sliding_window) if + self.sliding_window is not None else (-1, -1)) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + input_metadata: InputMetadata, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + key_cache: shape = [num_blocks, num_kv_heads, head_size/x, + block_size, x] + value_cache: shape = [num_blocks, num_kv_heads, head_size, + block_size] + input_metadata: metadata for the inputs. + Returns: + shape = [batch_size, seq_len, num_heads * head_size] + """ + batch_size, seq_len, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + # Reshape the keys and values and store them in the cache. + # If key_cache and value_cache are not provided, the new key and value + # vectors will not be cached. This happens during the initial memory + # profiling run. + if key_cache is not None and value_cache is not None: + PagedAttentionImpl.reshape_and_cache(key, value, key_cache, + value_cache, input_metadata) + + if input_metadata.is_prompt: + # Prompt run. + if (key_cache is None or value_cache is None + or input_metadata.block_tables.numel() == 0): + # normal attention + query = query.unflatten(0, (batch_size, seq_len)) + key = key.unflatten(0, (batch_size, seq_len)) + value = value.unflatten(0, (batch_size, seq_len)) + output = flash_attn_func( + query, + key, + value, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) + else: + # prefix-enabled attention + output = PagedAttentionImpl.forward_prefix( + query, + key, + value, + key_cache, + value_cache, + input_metadata, + self.num_heads, + self.num_kv_heads, + self.alibi_slopes, + ) + else: + # Decoding run. + output = PagedAttentionImpl.forward_decode( + query, + key_cache, + value_cache, + input_metadata, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + ) + + # Reshape the output tensor. + return output.view(batch_size, seq_len, hidden_size) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention/backends/xformers.py similarity index 56% rename from vllm/model_executor/layers/attention.py rename to vllm/model_executor/layers/attention/backends/xformers.py index 2a82325b80213..bad2a648b6703 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention/backends/xformers.py @@ -1,37 +1,19 @@ -"""Multi-head attention.""" +"""Attention layer with xFormers and PagedAttention.""" +import importlib from typing import List, Optional -import importlib import torch -import torch.nn as nn from xformers import ops as xops from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, LowerTriangularMaskWithTensorBias) -from vllm._C import ops -from vllm._C import cache_ops from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( - context_attention_fwd) +from vllm.model_executor.layers.attention.ops.paged_attn import ( + PagedAttentionImpl) from vllm.utils import is_hip -_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] -# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. -_PARTITION_SIZE = 512 - - -class PagedAttention(nn.Module): - """MHA/MQA/GQA layer with PagedAttention. - This class takes query, key, and value tensors as input. The input tensors - can either contain prompt tokens or generation tokens. - The class does the following: - - 1. Reshape and store the input key and value tensors in the KV cache. - 2. Perform (multi-head/multi-query/grouped-query) attention using either - xformers or the PagedAttention custom op. - 3. Return the output tensor. - """ +class XFormersBackend: def __init__( self, @@ -42,7 +24,6 @@ def __init__( alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, ) -> None: - super().__init__() self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -50,48 +31,17 @@ def __init__( self.sliding_window = sliding_window if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.register_buffer("alibi_slopes", alibi_slopes, persistent=False) + self.alibi_slopes = alibi_slopes assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads + suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") - if self.head_size not in _SUPPORTED_HEAD_SIZES: - raise ValueError(f"head_size ({self.head_size}) is not supported. " - f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.") - - self.use_ref_attention = self.check_use_ref_attention() - - def check_use_ref_attention(self) -> bool: - if not is_hip(): - return False - # For ROCm, check whether flash attention is installed or not. - # if not, use_ref_attention needs to be True - return importlib.util.find_spec("flash_attn") is None - - def ref_masked_attention( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - ) -> torch.Tensor: - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - seq_len, _, _ = query.shape - attn_mask = torch.triu(torch.ones(seq_len, - seq_len, - dtype=query.dtype, - device=query.device), - diagonal=1) - attn_mask = attn_mask * torch.finfo(query.dtype).min - - attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query, - key).float() - attn_weights = attn_weights + attn_mask.float() - attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) - out = torch.einsum("hqk,khd->qhd", attn_weights, value) - return out + self.use_ref_attention = _check_use_ref_attention() def forward( self, @@ -102,7 +52,7 @@ def forward( value_cache: Optional[torch.Tensor], input_metadata: InputMetadata, ) -> torch.Tensor: - """PagedAttention forward pass. + """Forward pass with xFormers and PagedAttention. Args: query: shape = [batch_size, seq_len, num_heads * head_size] @@ -127,19 +77,14 @@ def forward( # vectors will not be cached. This happens during the initial memory # profiling run. if key_cache is not None and value_cache is not None: - cache_ops.reshape_and_cache( - key, - value, - key_cache, - value_cache, - input_metadata.slot_mapping.flatten(), - input_metadata.kv_cache_dtype, - ) + PagedAttentionImpl.reshape_and_cache(key, value, key_cache, + value_cache, input_metadata) if input_metadata.is_prompt: - # normal attention + # Prompt run. if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): + # normal attention if self.num_kv_heads != self.num_heads: # As of Nov 2023, xformers only supports MHA. For MQA/GQA, # project the key and value tensors to the desired number of @@ -175,13 +120,19 @@ def forward( seq_len, query.dtype) if self.use_ref_attention: - output = self.ref_masked_attention( + output = _ref_masked_attention( query, key, value, + self.num_heads, + self.num_kv_heads, + self.head_size, + self.scale, ) - # Using view got RuntimeError: view size is not compatible with input tensor's size and stride - # (at least one dimension spans across two contiguous subspaces). Use reshape instead + # Using view got RuntimeError: view size is not compatible + # with input tensor's size and stride (at least one + # dimension spans across two contiguous subspaces). + # Use reshape instead. return output.reshape(batch_size, seq_len, hidden_size) # TODO(woosuk): Too many view operations. Let's try to reduce @@ -206,27 +157,21 @@ def forward( (is_hip()) else None, ) output = out.view_as(query) + else: # prefix-enabled attention - output = torch.empty_like(query) - context_attention_fwd( + output = PagedAttentionImpl.forward_prefix( query, key, value, - output, key_cache, value_cache, - input_metadata.block_tables, # [BS, max_block_per_request] - input_metadata.start_loc, - input_metadata.prompt_lens, - input_metadata.context_lens, - input_metadata.max_seq_len, - getattr(self, "alibi_slopes", None), + input_metadata, + self.alibi_slopes, ) - else: # Decoding run. - output = _paged_attention( + output = PagedAttentionImpl.forward_decode( query, key_cache, value_cache, @@ -274,76 +219,37 @@ def _make_alibi_bias( return attn_bias -def _paged_attention( +def _check_use_ref_attention() -> bool: + if not is_hip(): + return False + # For ROCm, check whether flash attention is installed or not. + # if not, use_ref_attention needs to be True + return importlib.util.find_spec("flash_attn") is None + + +def _ref_masked_attention( query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - input_metadata: InputMetadata, + key: torch.Tensor, + value: torch.Tensor, + num_heads: int, num_kv_heads: int, + head_size: int, scale: float, - alibi_slopes: Optional[torch.Tensor], ) -> torch.Tensor: - output = torch.empty_like(query) - - block_size = value_cache.shape[3] - num_seqs, num_heads, head_size = query.shape - max_num_partitions = ( - (input_metadata.max_context_len + _PARTITION_SIZE - 1) // - _PARTITION_SIZE) - # NOTE(woosuk): We use a simple heuristic to decide whether to use - # PagedAttention V1 or V2. If the number of partitions is 1, we use - # V1 to avoid the overhead of reduction. Also, if the number of - # sequences or heads is large, we use V1 since there is enough work - # to parallelize. - # TODO(woosuk): Tune this heuristic. - # For context len > 8192, use V2 kernel to avoid shared memory shortage. - use_v1 = input_metadata.max_context_len <= 8192 and ( - max_num_partitions == 1 or num_seqs * num_heads > 512) - if use_v1: - # Run PagedAttention V1. - ops.paged_attention_v1( - output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - input_metadata.block_tables, - input_metadata.context_lens, - block_size, - input_metadata.max_context_len, - alibi_slopes, - input_metadata.kv_cache_dtype, - ) - else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=output.dtype, - device=output.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=output.device, - ) - max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - input_metadata.block_tables, - input_metadata.context_lens, - block_size, - input_metadata.max_context_len, - alibi_slopes, - input_metadata.kv_cache_dtype, - ) - return output + query = query.view(-1, num_heads, head_size) + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + + seq_len, _, _ = query.shape + attn_mask = torch.triu(torch.ones(seq_len, + seq_len, + dtype=query.dtype, + device=query.device), + diagonal=1) + attn_mask = attn_mask * torch.finfo(query.dtype).min + + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out diff --git a/vllm/model_executor/layers/attention/ops/__init__.py b/vllm/model_executor/layers/attention/ops/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/attention/ops/paged_attn.py b/vllm/model_executor/layers/attention/ops/paged_attn.py new file mode 100644 index 0000000000000..c5a9618c2395b --- /dev/null +++ b/vllm/model_executor/layers/attention/ops/paged_attn.py @@ -0,0 +1,138 @@ +from typing import List, Optional + +import torch + +from vllm._C import cache_ops +from vllm._C import ops +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.attention.ops.prefix_prefill import ( + context_attention_fwd) + +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +_PARTITION_SIZE = 512 + + +class PagedAttentionImpl: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 80, 96, 112, 128, 256] + + @staticmethod + def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + input_metadata: InputMetadata, + ) -> None: + cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + input_metadata.slot_mapping.flatten(), + input_metadata.kv_cache_dtype, + ) + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + input_metadata: InputMetadata, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + ) -> torch.Tensor: + output = torch.empty_like(query) + + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ( + (input_metadata.max_context_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory shortage. + use_v1 = input_metadata.max_context_len <= 8192 and ( + max_num_partitions == 1 or num_seqs * num_heads > 512) + if use_v1: + # Run PagedAttention V1. + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + alibi_slopes, + input_metadata.kv_cache_dtype, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + alibi_slopes, + input_metadata.kv_cache_dtype, + ) + return output + + @staticmethod + def forward_prefix( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + input_metadata: InputMetadata, + alibi_slopes: Optional[torch.Tensor], + ) -> torch.Tensor: + output = torch.empty_like(query) + context_attention_fwd( + query, + key, + value, + output, + key_cache, + value_cache, + input_metadata.block_tables, # [BS, max_block_per_request] + input_metadata.start_loc, + input_metadata.prompt_lens, + input_metadata.context_lens, + input_metadata.max_seq_len, + alibi_slopes, + ) + return output diff --git a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py b/vllm/model_executor/layers/attention/ops/prefix_prefill.py similarity index 100% rename from vllm/model_executor/layers/triton_kernel/prefix_prefill.py rename to vllm/model_executor/layers/attention/ops/prefix_prefill.py diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 550dec6487f9e..6da0082b94285 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -27,7 +27,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -151,10 +151,10 @@ def __init__( alibi_slopes = alibi_slopes[head_start:head_end].tolist() scaling = self.head_dim**-0.5 - self.attn = PagedAttention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes) + self.attn = Attention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes) else: self.rotary_emb = get_rope( self.head_dim, @@ -163,8 +163,7 @@ def __init__( base=self.rope_theta, ) self.scaling = self.head_dim**-0.5 - self.attn = PagedAttention(self.num_heads, self.head_dim, - self.scaling) + self.attn = Attention(self.num_heads, self.head_dim, self.scaling) def forward( self, diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 4adfb6b78102f..0548b2b140b1b 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -25,7 +25,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -107,10 +107,10 @@ def __init__( alibi_slopes = alibi_slopes[head_start:head_end].tolist() scaling = self.head_dim**-0.5 - self.attn = PagedAttention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes) + self.attn = Attention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes) def forward( self, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index dca8d724f976b..1c5dcfacaff2b 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -10,7 +10,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -87,7 +87,7 @@ def __init__( base=10000 * rope_ratio, is_neox_style=False, ) - self.attn = PagedAttention( + self.attn = Attention( self.num_heads, self.head_dim, self.scaling, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 6dba952736921..f2dca3df27cfb 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -29,7 +29,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -229,10 +229,10 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) def forward( self, diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 2b5e022312e3b..3c148be5b10f4 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -28,7 +28,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -150,10 +150,10 @@ def __init__( max_position=max_position_embeddings, base=rope_theta, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.inv_norm_factor, - num_kv_heads=self.num_kv_heads) + self.attn = Attention(self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads) elif self.use_alibi: tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads @@ -161,16 +161,16 @@ def __init__( alibi_slopes = (_get_alibi_slopes(self.total_num_heads) * self.inv_norm_factor) alibi_slopes = alibi_slopes[head_start:head_end].tolist() - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.inv_norm_factor, - num_kv_heads=self.num_kv_heads, - alibi_slopes=alibi_slopes) + self.attn = Attention(self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + alibi_slopes=alibi_slopes) else: - self.attn = PagedAttention(self.num_heads, - self.head_dim, - scale=self.inv_norm_factor, - num_kv_heads=self.num_kv_heads) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.inv_norm_factor, + num_kv_heads=self.num_kv_heads) def forward( self, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index bf1f164ff700d..386a36cf492d6 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -23,7 +23,7 @@ from vllm.config import LoRAConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import GeluAndMul -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -123,10 +123,10 @@ def __init__(self, base=self.rope_theta, is_neox_style=True, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) def forward( self, diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 661da0fe0434e..3f7b21e5a4133 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -25,7 +25,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -73,9 +73,7 @@ def __init__( bias=True, linear_method=linear_method, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - scale=self.scale) + self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale) def forward( self, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index ef4c1d4143c88..5c30d47d93e36 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -26,7 +26,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -85,10 +85,10 @@ def __init__( bias=True, linear_method=linear_method, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - scale=self.scale, - num_kv_heads=self.num_kv_heads) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + num_kv_heads=self.num_kv_heads) def forward( self, diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 5bab30d9d442e..b8c6822e9825e 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -24,7 +24,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -86,7 +86,7 @@ def __init__( base=rope_theta, is_neox_style=False, ) - self.attn = PagedAttention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, self.head_size, scaling) def forward( self, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 8f7e1063e0c1d..98107350e60b9 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -24,7 +24,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -87,7 +87,7 @@ def __init__( max_position=max_position_embeddings, base=rope_theta, ) - self.attn = PagedAttention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, self.head_size, scaling) def forward( self, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index ebf1d8a89a022..0ae0a85643456 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -7,7 +7,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -114,10 +114,10 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) def forward( self, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d35887cc0f6a3..4c163dfdab537 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -30,7 +30,7 @@ from vllm.config import LoRAConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -139,11 +139,11 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=sliding_window) def forward( self, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 0100624a44d78..d47834e519697 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -29,7 +29,7 @@ from vllm.config import LoRAConfig from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -197,7 +197,7 @@ def __init__(self, base=int(self.rope_theta), is_neox_style=True, ) - self.attn = PagedAttention( + self.attn = Attention( self.num_heads, self.head_dim, self.scaling, diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index a8dadce24aa1d..25c7f1978c0dc 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -32,7 +32,7 @@ from transformers import MixtralConfig from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, ReplicatedLinear, @@ -214,7 +214,7 @@ def __init__(self, base=int(self.rope_theta), is_neox_style=True, ) - self.attn = PagedAttention( + self.attn = Attention( self.num_heads, self.head_dim, self.scaling, diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 22a876e2ef691..16ecac3d0529a 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -8,7 +8,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -105,11 +105,11 @@ def __init__( self.head_dim = self.d_model // self.total_num_heads scaling = self.head_dim**-0.5 - self.attn = PagedAttention(self.num_heads, - self.head_dim, - scaling, - alibi_slopes=alibi_slopes, - num_kv_heads=self.num_kv_heads) + self.attn = Attention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + num_kv_heads=self.num_kv_heads) def forward( self, diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 9d563039208c8..fa7a6d850051e 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -43,7 +43,7 @@ from torch import nn from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import ( ColumnParallelLinear, LinearMethodBase, @@ -126,9 +126,9 @@ def __init__( base=rope_theta, ) self.scaling = self.head_dim**-0.5 - self.attn = PagedAttention(self.num_heads, - self.head_dim, - scale=self.scaling) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scaling) # Attention output projection. self.attn_out = RowParallelLinear( diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 393b2dcabcd5a..782f43ce265bd 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -25,7 +25,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -89,9 +89,9 @@ def __init__( bias=bias, linear_method=linear_method, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - scale=self.scaling) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scaling) def forward( self, diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 0b067d4fc8802..6039b1cdc3534 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -12,7 +12,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, @@ -118,10 +118,10 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) def forward( self, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index d143261968288..039dc7a9b7675 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -43,7 +43,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -108,7 +108,7 @@ def __init__(self, max_position=max_position_embeddings, base=rope_theta, ) - self.attn = PagedAttention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, self.head_size, scaling) def forward( self, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 37af84c7cd53f..d4d5a4e8bb9a5 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -12,7 +12,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -104,7 +104,7 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling) + self.attn = Attention(self.num_heads, self.head_dim, self.scaling) def forward( self, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index e823e6f8c3dbe..3586a7fb82778 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -30,7 +30,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -135,11 +135,11 @@ def __init__(self, max_position=max_position, base=self.rope_theta, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window) def forward( self, diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 44c57e5a6d4f9..d1a547f815616 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -25,7 +25,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, @@ -122,10 +122,10 @@ def __init__(self, max_position=self.config.max_position_embeddings, base=self.config.rope_theta, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_key_value_heads) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_key_value_heads) def forward( self, diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 1eda07b724cae..efa235233372f 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -25,7 +25,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -103,7 +103,7 @@ def __init__(self, base=int(self.rope_theta), is_neox_style=True, ) - self.attn = PagedAttention( + self.attn = Attention( self.num_heads, self.head_dim, self.scaling,