[Kernel] Flashinfer for prefill & decode, with Cudagraph support for decode #4628

Merged
merged 47 commits on Jun 28, 2024
Changes from 21 commits
Commits (47)
0eb1ab1  flashinfer for prefill (May 6, 2024)
4590b46  minor (May 6, 2024)
3bfbdf7  fix docker (May 7, 2024)
993a4ae  work for prefix caching (May 7, 2024)
b4d9dae  dedup test (May 7, 2024)
eb2d18e  Merge branch 'main' into flashinfer-prefill (May 28, 2024)
5e3d11d  format (May 29, 2024)
89f0e2c  fix test (May 29, 2024)
72e704b  remove flashinfer from ci (May 30, 2024)
f9770ed  wip, cuda graph for decode (Jun 4, 2024)
f1849f7  wip (Jun 4, 2024)
88425a3  pass tests (Jun 4, 2024)
74a8eeb  wip (Jun 5, 2024)
dcbbfd6  pass simple tests, need more fix for correctness (LiuXiaoxuanPKU, Jun 11, 2024)
4302848  optimizer prepare input (LiuXiaoxuanPKU, Jun 13, 2024)
d739312  padding (LiuXiaoxuanPKU, Jun 13, 2024)
e5017e2  Merge branch 'main' into flashinfer-prefill (LiuXiaoxuanPKU, Jun 13, 2024)
5ad175a  style (LiuXiaoxuanPKU, Jun 13, 2024)
543dc3b  share workspace buffer to reduce cudagraph extra memory cost (LiuXiaoxuanPKU, Jun 14, 2024)
11b7347  address comments (LiuXiaoxuanPKU, Jun 17, 2024)
b5db4be  fix (LiuXiaoxuanPKU, Jun 17, 2024)
f53d03e  fix comments (LiuXiaoxuanPKU, Jun 18, 2024)
6fb1b6d  Merge branch 'main' into flashinfer-prefill (LiuXiaoxuanPKU, Jun 18, 2024)
e05ff79  support TP > 1 (LiuXiaoxuanPKU, Jun 19, 2024)
8f685dd  try CI (LiuXiaoxuanPKU, Jun 20, 2024)
0f8e7a1  minor (LiuXiaoxuanPKU, Jun 20, 2024)
cf275a1  format (LiuXiaoxuanPKU, Jun 20, 2024)
0ab32ee  minor (LiuXiaoxuanPKU, Jun 20, 2024)
c421f1f  try CI (LiuXiaoxuanPKU, Jun 20, 2024)
815efc2  flash attention dependency (LiuXiaoxuanPKU, Jun 20, 2024)
901b369  minor (LiuXiaoxuanPKU, Jun 20, 2024)
b2d9895  flash attn (LiuXiaoxuanPKU, Jun 21, 2024)
df16a6b  format (LiuXiaoxuanPKU, Jun 21, 2024)
64a24cb  Merge branch 'main' into flashinfer-prefill (LiuXiaoxuanPKU, Jun 21, 2024)
dc4e7ef  fix ci (LiuXiaoxuanPKU, Jun 22, 2024)
9774919  Merge branch 'main' into flashinfer-prefill (LiuXiaoxuanPKU, Jun 23, 2024)
aeb0df6  use llama3-8b in test and add warning (LiuXiaoxuanPKU, Jun 25, 2024)
8a72dcf  fix (LiuXiaoxuanPKU, Jun 25, 2024)
aaddbad  remove amd tests (LiuXiaoxuanPKU, Jun 26, 2024)
4aa2069  Merge branch 'main' into flashinfer-prefill (LiuXiaoxuanPKU, Jun 26, 2024)
e61bd38  fix (LiuXiaoxuanPKU, Jun 27, 2024)
b2484df  minor (LiuXiaoxuanPKU, Jun 27, 2024)
0f4f796  fix (LiuXiaoxuanPKU, Jun 28, 2024)
3dca2f0  change buffer init (LiuXiaoxuanPKU, Jun 28, 2024)
7853235  fix ci (LiuXiaoxuanPKU, Jun 28, 2024)
8316bc3  Merge branch 'main' into flashinfer-prefill (LiuXiaoxuanPKU, Jun 28, 2024)
d5348f1  Merge branch 'main' into flashinfer-prefill (LiuXiaoxuanPKU, Jun 28, 2024)
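
For context before the file-level changes: the FlashInfer path is opted into through the VLLM_ATTENTION_BACKEND environment variable that the test diffs below reference, and with this PR decode no longer requires eager mode. A minimal usage sketch (the model name and prompt are only examples, not taken from this PR):

import os
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # select the FlashInfer backend

from vllm import LLM, SamplingParams

# enforce_eager=False exercises the CUDA-graph decode path added by this PR.
llm = LLM(model="meta-llama/Meta-Llama-3-8B", enforce_eager=False)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)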
6 changes: 0 additions & 6 deletions tests/basic_correctness/test_basic_correctness.py
@@ -2,7 +2,6 @@

Run `pytest tests/basic_correctness/test_basic_correctness.py`.
"""
import os
import weakref

import pytest
@@ -13,7 +12,6 @@
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
]
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"


def test_vllm_gc_ed():
@@ -39,10 +37,6 @@ def test_models(
max_tokens: int,
enforce_eager: bool,
) -> None:
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
if backend_by_env_var == "FLASHINFER" and enforce_eager is False:
pytest.skip("Skipping non-eager test for FlashInferBackend.")

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

5 changes: 0 additions & 5 deletions tests/distributed/test_basic_distributed_correctness.py
@@ -21,7 +21,6 @@
os.environ["TEST_DIST_MODEL"],
]
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"


@pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -39,16 +38,12 @@ def test_models(
) -> None:
distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)

backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
enforce_eager = backend_by_env_var == "FLASHINFER"

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=2,
enforce_eager=enforce_eager,
distributed_executor_backend=distributed_executor_backend
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
83 changes: 58 additions & 25 deletions vllm/attention/backends/flashinfer.py
@@ -1,10 +1,17 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type

import flashinfer
try:
from flash_attn import flash_attn_varlen_func
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
except ImportError:
flashinfer = None
flash_attn_varlen_func = None
BatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None

import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from vllm_flash_attn import flash_attn_varlen_func

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -60,19 +67,16 @@ class FlashInferMetadata(AttentionMetadata):
# requests only.
max_prefill_seq_len: int

use_cuda_graph: bool = False
use_cuda_graph: bool = True

Collaborator:
QQ: can we make it not a kwarg?

Collaborator (Author):
Here I'm following other backends, which declare it as:

use_cuda_graph: bool


prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

# Metadata for the prefill stage since we still
# use flash attention for prefill.
# Metadata for the prefill stage
seq_start_loc: Optional[torch.Tensor] = None
query_start_loc: Optional[torch.Tensor] = None
block_tables: Optional[torch.Tensor] = None

# Metadata for the decode stage
# Workspace buffer required by the kernel; the buffer should not
# be allocated/deallocated by the FlashInferMetadata object.
workspace_buffer: Optional[torch.Tensor] = None
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
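# Illustrative note (not from the original diff): for the two requests above,
# paged_kv_indices = [0, 5, 8, 1, 6, 7] and paged_kv_indptr = [0, 3, 6], so
# request i owns paged_kv_indices[paged_kv_indptr[i]:paged_kv_indptr[i+1]],
# while paged_kv_last_page_len stores how many slots of each request's last
# page are actually filled.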
@@ -98,6 +102,7 @@ class FlashInferMetadata(AttentionMetadata):
page_size: Optional[int] = None
# The data type of the paged kv cache
data_type: torch.dtype = None
device: torch.device = torch.device("cuda")

def __post_init__(self):
# Refer to
@@ -109,13 +114,34 @@ def __post_init__(self):
f"Only {supported_head_sizes} are supported for head_dim,",
f"received {self.head_dim}.")

# When using flashinfer, we are also creating the FlashInferMetadata,
# which will also call post_init by default, here we want to skip the
# post_init if it's the prefill phase.
if self.num_prefills == 0:
assert self.num_decode_tokens > 0
self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD")
if self.num_prefill_tokens > 0:
if self.paged_kv_indices is None:
return

assert self.prefill_wrapper is not None
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
self.prefill_wrapper.begin_forward(
self.query_start_loc, self.paged_kv_indptr,
self.paged_kv_indices, self.paged_kv_last_page_len,
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
else:
if not self.use_cuda_graph:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)

assert self.decode_wrapper is not None
self.decode_wrapper.begin_forward(
self.paged_kv_indptr,
self.paged_kv_indices,
@@ -133,8 +159,9 @@ def asdict_zerocopy(self,
) -> Dict[str, Any]:
if skip_fields is None:
skip_fields = set()
# We need to skip the decode_wrapper field since it cannot be
# We need to skip the prefill/decode_wrapper field since it cannot be
# broadcasted with nccl when TP is enabled.
skip_fields.add('prefill_wrapper')
skip_fields.add('decode_wrapper')
return super().asdict_zerocopy(skip_fields)

@@ -168,6 +195,7 @@ def __init__(
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -217,10 +245,14 @@ def forward(
self.kv_cache_dtype,
)

query = query.contiguous(

Collaborator:
I think we discussed this before, but what's the overhead of this call?

Collaborator (Author):
For llama7b on A100, the shape of query is [256, 12, 64], and this line takes ~0.037 ms.

(See the rough timing sketch after this file's diff.)

) # Flashinfer requires query to be contiguous
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.block_tables is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# We will use flash attention for prefill

Collaborator:
Why is that? Can you comment? (Also, is it fundamental?)

Collaborator (Author):
This only happens during the profiling phase, where the cache is initialized (not paged). We use flash attention for the profile run.

# when kv_cache is not provided.
# This happens when vllm runs the profiling to
# determine the number of blocks.
if kv_cache is None:
output = flash_attn_varlen_func(
q=query,
k=key,
Expand All @@ -235,13 +267,14 @@ def forward(
alibi_slopes=self.alibi_slopes,
)
else:
raise NotImplementedError(
"Prefix caching is not supported with flashinfer yet.")
assert prefill_meta is not None
assert prefill_meta.prefill_wrapper is not None
output = prefill_meta.prefill_wrapper.forward(query,
kv_cache,
causal=True)
else:
assert attn_metadata.decode_metadata is not None
assert attn_metadata.decode_metadata.decode_wrapper is not None
query = query.contiguous(
) # Flashinfer requires query to be contiguous
output = attn_metadata.decode_metadata.decode_wrapper.forward(
query,
kv_cache,
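Following up on the .contiguous() overhead question in the thread above, a rough timing sketch. The [256, 12, 64] query shape comes from the author's reply; the slicing trick used to produce a non-contiguous tensor is only an assumption about how to mimic the situation, not the actual vLLM code path:

import torch

qkv = torch.randn(256, 36, 64, dtype=torch.float16, device="cuda")
query = qkv[:, :12, :]  # non-contiguous view with shape [256, 12, 64]

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
for _ in range(10):  # warm-up iterations
    query.contiguous()
torch.cuda.synchronize()
start.record()
for _ in range(100):
    query.contiguous()
end.record()
torch.cuda.synchronize()
print(f"contiguous(): {start.elapsed_time(end) / 100:.4f} ms per call")
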
2 changes: 0 additions & 2 deletions vllm/attention/selector.py
@@ -66,8 +66,6 @@ def get_attn_backend(
return TorchSDPABackend
elif backend == _Backend.FLASHINFER:
logger.info("Using Flashinfer backend.")
logger.warning("Eager mode is required for the Flashinfer backend. "
"Please make sure --enforce-eager is set.")
from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend
elif backend == _Backend.PALLAS:
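
Putting the pieces together, a standalone sketch of the FlashInfer decode path this PR wires up. It is based on the calls visible in the flashinfer.py diff above, not on the exact vLLM code; the head/page sizes, the 128 MB workspace size, and the (num_pages, 2, page_size, num_kv_heads, head_dim) cache layout for "NHD" are assumptions and may differ from what vLLM actually allocates or from other flashinfer versions:

import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16

# Shared workspace buffer, echoing the "share workspace buffer" commit above.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")

# Two decode requests using pages [0, 5, 8] and [1, 6, 7], as in the
# paged_kv_indices / paged_kv_indptr comment in the diff above.
paged_kv_indices = torch.tensor([0, 5, 8, 1, 6, 7], dtype=torch.int32, device="cuda")
paged_kv_indptr = torch.tensor([0, 3, 6], dtype=torch.int32, device="cuda")
paged_kv_last_page_len = torch.tensor([7, 11], dtype=torch.int32, device="cuda")

decode_wrapper.begin_forward(paged_kv_indptr, paged_kv_indices,
                             paged_kv_last_page_len, num_qo_heads,
                             num_kv_heads, head_dim, page_size)

query = torch.randn(2, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
kv_cache = torch.randn(16, 2, page_size, num_kv_heads, head_dim,
                       dtype=torch.float16, device="cuda")  # assumed paged layout
output = decode_wrapper.forward(query, kv_cache)
decode_wrapper.end_forward()
print(output.shape)  # one output token per request: [2, num_qo_heads, head_dim]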