[Kernel] Flashinfer for prefill & decode, with Cudagraph support for decode (vllm-project#4628)

Co-authored-by: LiuXiaoxuanPKU <[email protected]>, bong-furiosa <[email protected]>
2 people authored and prashantgupta24 committed Jul 1, 2024
1 parent a6b188d commit 96e23ec
Showing 7 changed files with 313 additions and 117 deletions.
3 changes: 3 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -211,3 +211,6 @@ steps:
- pytest -v -s distributed/test_custom_all_reduce.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
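
The CI entries above enable FlashInfer solely through the VLLM_ATTENTION_BACKEND environment variable. As a minimal usage sketch (not part of this diff; the model choice and sampling settings are illustrative), the same switch can be exercised from Python:

import os

# Must be set before vLLM selects an attention backend.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

# opt-125m matches the model used in the tests above.
llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)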
2 changes: 1 addition & 1 deletion requirements-test.txt
@@ -19,4 +19,4 @@ sentence-transformers # required for embedding
aiohttp

# quantization
bitsandbytes==0.42.0
bitsandbytes==0.42.0
6 changes: 0 additions & 6 deletions tests/basic_correctness/test_basic_correctness.py
@@ -2,7 +2,6 @@
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
"""
import os
import weakref

import pytest
@@ -13,7 +12,6 @@
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
]
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"


def test_vllm_gc_ed():
@@ -39,10 +37,6 @@ def test_models(
max_tokens: int,
enforce_eager: bool,
) -> None:
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
if backend_by_env_var == "FLASHINFER" and enforce_eager is False:
pytest.skip("Skipping non-eager test for FlashInferBackend.")

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

5 changes: 0 additions & 5 deletions tests/distributed/test_basic_distributed_correctness.py
@@ -21,7 +21,6 @@
os.environ["TEST_DIST_MODEL"],
]
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"


@pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -39,16 +38,12 @@ def test_models(
) -> None:
distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)

backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
enforce_eager = backend_by_env_var == "FLASHINFER"

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=2,
enforce_eager=enforce_eager,
distributed_executor_backend=distributed_executor_backend
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
83 changes: 58 additions & 25 deletions vllm/attention/backends/flashinfer.py
@@ -1,10 +1,16 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type

import flashinfer
try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
BatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None

import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -60,19 +66,16 @@ class FlashInferMetadata(AttentionMetadata):
# requests only.
max_prefill_seq_len: int

use_cuda_graph: bool = False
use_cuda_graph: bool = True

prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

# Metadata for the prefill stage since we still
# use flash attention for prefill.
# Metadata for the prefill stage
seq_start_loc: Optional[torch.Tensor] = None
query_start_loc: Optional[torch.Tensor] = None
block_tables: Optional[torch.Tensor] = None

# Metadata for the decode stage
# Workspace buffer required by the kernel, the buffer should not
# be allocated/deallocated by the FlashInferMetadata object.
workspace_buffer: Optional[torch.Tensor] = None
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
@@ -98,6 +101,7 @@ class FlashInferMetadata(AttentionMetadata):
page_size: Optional[int] = None
# The data type of the paged kv cache
data_type: torch.dtype = None
device: torch.device = torch.device("cuda")

def __post_init__(self):
# Refer to
@@ -109,13 +113,35 @@ def __post_init__(self):
f"Only {supported_head_sizes} are supported for head_dim,",
f"received {self.head_dim}.")

# When using flashinfer, we are also creating the FlashInferMetadata,
# which will also call post_init by default; here we want to skip the
# post_init if it's the prefill phase.
if self.num_prefills == 0:
assert self.num_decode_tokens > 0
self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD")
def begin_forward(self):
if self.num_prefill_tokens > 0:
if self.paged_kv_indices is None:
return

assert self.prefill_wrapper is not None
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
self.prefill_wrapper.begin_forward(
self.query_start_loc, self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_last_page_len,
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
else:
if not self.use_cuda_graph:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)

assert self.decode_wrapper is not None
self.decode_wrapper.begin_forward(
self.paged_kv_indptr,
self.paged_kv_indices,
@@ -133,8 +159,9 @@ def asdict_zerocopy(self,
) -> Dict[str, Any]:
if skip_fields is None:
skip_fields = set()
# We need to skip the decode_wrapper field since it cannot be
# We need to skip the prefill/decode_wrapper field since it cannot be
# broadcasted with nccl when TP is enabled.
skip_fields.add('prefill_wrapper')
skip_fields.add('decode_wrapper')
return super().asdict_zerocopy(skip_fields)

@@ -168,6 +195,7 @@ def __init__(
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -217,10 +245,14 @@ def forward(
self.kv_cache_dtype,
)

query = query.contiguous(
) # Flashinfer requires query to be contiguous
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.block_tables is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# We will use flash attention for prefill
# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache is None:
output = flash_attn_varlen_func(
q=query,
k=key,
@@ -235,13 +267,14 @@
alibi_slopes=self.alibi_slopes,
)
else:
raise NotImplementedError(
"Prefix caching is not supported with flashinfer yet.")
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
output = prefill_meta.prefill_wrapper.forward(query,
kv_cache,
causal=True)
else:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
query = query.contiguous(
) # Flashinfer requires query to be contiguous
output = attn_metadata.decode_metadata.decode_wrapper.forward(
query,
kv_cache,
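
For context, here is a minimal sketch (not part of this diff) of how the prefill/decode wrappers that the new begin_forward() expects might be constructed by the caller; the workspace-buffer size and exact call site are assumptions, while the "NHD" layout mirrors the wrapper construction removed above:

import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper

# Workspace buffer shared with the FlashInfer kernels (128 MiB is an assumed size).
workspace_buffer = torch.empty(128 * 1024 * 1024,
                               dtype=torch.uint8,
                               device="cuda")

# Paged KV cache wrappers in "NHD" layout, matching the code above.
prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")

# The runner would attach these to FlashInferMetadata and call
# metadata.begin_forward() once per step before the attention forward pass.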
5 changes: 3 additions & 2 deletions vllm/attention/selector.py
@@ -77,8 +77,9 @@ def get_attn_backend(
return IpexAttnBackend
elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.")
logger.warning("Eager mode is required for the Flashinfer backend. "
"Please make sure --enforce-eager is set.")
logger.warning(("Flashinfer will be stuck on llma-2-7b,"
" please avoid using Flashinfer as the"
"backend when running on llma-2-7b."))
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
elif backend == _Backend.PALLAS:
