Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SpecDecode] Support FlashInfer in DraftModelRunner #6926

Merged
merged 13 commits into from
Aug 5, 2024

Conversation

bong-furiosa
Copy link
Contributor

FILL IN THE PR DESCRIPTION HERE

FIX #6885

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE

This PR resolves the benchmark_serving.py error for Speculative Decoding when using FlashInfer Backend.

After fixing, we can see that the benchmark results are correctly displayed in both FlashAttn Backend and FlashInfer Backend.

FLASH_ATTN backend

============ Serving Benchmark Result ============
Successful requests:                     100       
Benchmark duration (s):                  22.43     
Total input tokens:                      26322     
Total generated tokens:                  11433     
Request throughput (req/s):              4.46      
Input token throughput (tok/s):          1173.60   
Output token throughput (tok/s):         509.76    
---------------Time to First Token----------------
Mean TTFT (ms):                          54.68     
Median TTFT (ms):                        37.58     
P99 TTFT (ms):                           221.78    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          44.23     
Median TPOT (ms):                        34.16     
P99 TPOT (ms):                           158.26    
---------------Inter-token Latency----------------
Mean ITL (ms):                           70.08     
Median ITL (ms):                         55.27     
P99 ITL (ms):                            363.48    
==================================================

FLASHINFER Backend

============ Serving Benchmark Result ============
Successful requests:                     100       
Benchmark duration (s):                  21.89     
Total input tokens:                      26322     
Total generated tokens:                  11425     
Request throughput (req/s):              4.57      
Input token throughput (tok/s):          1202.55   
Output token throughput (tok/s):         521.97    
---------------Time to First Token----------------
Mean TTFT (ms):                          68.02     
Median TTFT (ms):                        57.64     
P99 TTFT (ms):                           318.30    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          44.91     
Median TPOT (ms):                        32.33     
P99 TPOT (ms):                           177.95    
---------------Inter-token Latency----------------
Mean ITL (ms):                           66.46     
Median ITL (ms):                         50.02     
P99 ITL (ms):                            346.13    
==================================================

PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@bong-furiosa
Copy link
Contributor Author

/ready

@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 30, 2024
@comaniac comaniac changed the title [Bugifx] fixed draft_model_runner.py #6885 [SpecDecode] Support FlashInfer in DraftModelRunner Jul 30, 2024
@comaniac
Copy link
Collaborator

Thanks for the PR! A question I have is we currently have a guard to make sure only FlashAttention can use draft model runner, so if you use FlashInfer, vLLM shouldn't use draft model runner. I'm curious how this PR could fix the issue you mentioned.

@comaniac comaniac removed the ready ONLY add when PR is ready to merge/full CI is needed label Jul 30, 2024
@bong-furiosa
Copy link
Contributor Author

@comaniac Thank you for your interest on this PR!

Please let me explain the purpose of this PR first, and then I'll describe the issues it addresses. 🙇

Purpose of this PR

Indeed, the current Speculative Decoding in vLLM selects the FlashAttention as a default backend.
As far as I know, to enable the FlashInfer backend, we need to manually set the environment variable VLLM_ATTENTION_BACKEND=FLASHINFER as mentioned in #4353 (link)

However, I found that recently @LiuXiaoxuanPKU has been working on integrating FlashInfer as a backend in vLLM and extending support to CUDAGraph for general LLM inference (#4353, #4628).

Additionally, according to the Speculative Decoding (with Batch Expansion, I guess...) experiments conducted by @cadedaniel, it was confirmed that using FlashInfer resulted in lower latency compared to other techniques.

I anticipate that FlashInfer will be used as the vLLM Speculative Decoding backend in the future. Although, implementing FlashInfer with CUDA Graph seems really challenging.

Therefore, I did PR modifying the code to enable the use of FlashInfer as the backend for vLLM Speculative Decoding in the current v0.5.3.post1. I believe this addresses a bug (or oversight?) that has been present since draft_model_runner was first introduced in v0.5.1. 🤔

Addressed issue

The solution of the issue was simple. I copied the necessary "flashinfer" related code blocks from the ModelRunner class in model_runner.py directly into draft_model_runner.py.

  1. As mentioned in [Bug]: Speculative Decoding + FlashInfer + benchmark_serving.py TransferEncodingError ISSUE #6885 , this PR addresses the issue where vLLM Speculative Decoding did not work when the FlashInfer backend was activated.

  2. In spec_decode_worker.py, the acceptance rates for the FlashAttention backend and the FlashInfer backend were recorded as 0.42 and 0.41, respectively (manually calculated in SpecDecodeWorker._run_speculative_decoding_step).

Although it may take a considerable amount of time for FlashInfer to be integrated officially into the Speculative Decoding in vLLM, I believe this PR addresses bugs that could arise during that process.

If the answer is not what you were looking for, please feel free to ask additional questions or close this PR. 🙇

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise LGTM. Can you add a unit test?

Comment on lines 14 to 17
# (bong-furiosa)
# Resolve the issue of the wrapper variable
# not being defined during the execution of
# FlashInfer backend Speculative Decoding.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is not necessary as the code itself is straightforward.

Comment on lines 96 to 99
# (bong-furiosa)
# Resolve the issue of the wrapper variable
# not being defined during the execution of
# FlashInfer backend Speculative Decoding.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto. Unnecessary comments.

Comment on lines 312 to 315
# (bong-furiosa)
# Resolve the issue of the wrapper variable
# not being defined during the execution of
# FlashInfer backend Speculative Decoding.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto. Unnecessary comments.

Comment on lines 316 to 345
if self.attn_backend.get_name() == "flashinfer":
assert model_input.attn_metadata is not None
assert model_input.input_tokens is not None
if self.flashinfer_decode_workspace_buffer is None:
self.flashinfer_decode_workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.device)
self.flashinfer_decode_wrapper = \
BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_decode_workspace_buffer, "NHD")
self.flashinfer_prefill_workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.device)
self.flashinfer_prefill_wrapper = \
BatchPrefillWithPagedKVCacheWrapper(
self.flashinfer_prefill_workspace_buffer, "NHD")

model_input.attn_metadata.prefill_wrapper = \
self.flashinfer_prefill_wrapper
if model_input.attn_metadata.use_cuda_graph:
batch_size = model_input.input_tokens.shape[0]
model_input.attn_metadata.decode_wrapper = self.graph_runners[
model_input.
virtual_engine][batch_size].flashinfer_decode_wrapper
else:
model_input.attn_metadata.decode_wrapper = \
self.flashinfer_decode_wrapper
model_input.attn_metadata.begin_forward()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be in the "else" branch above, because the if branch will always use flash attention backend.

@bong-furiosa
Copy link
Contributor Author

@comaniac Thank you for the advice on refining the code!
I would appreciate it if you could check the refined code.
Also, it looks like there have been many updates to vLLM since the current PR.
If any issues arise with the PR, I will sync and PR again.


Unfortunately, I do not have the time or ability to write unit test for the modified code...
I assume that using os.environ['VLLM_ATTENTION_BACKEND'] to test offline Spec Decoding in FLASH_ATTN or FLASHINFER backend would suffice. However, I am concerned about my ability to write the unit tests correctly. 😢

Therefore, would it be okay to expect that vLLM experts or other users write and confirm the unit test code? 🤔 🙇

@bong-furiosa
Copy link
Contributor Author

/ready

@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 4, 2024
@bong-furiosa
Copy link
Contributor Author

@comaniac Sorry for being late. I kept getting a failure log while checking the entrypoints(?) in fastcheck.
I would appreciate it if you could review the pr content and any additional advice when you have time. 🙇 🙇

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merge this PR first given that it doesn't fail existing cases. I'll try to find a time to add a unit test later.

@comaniac comaniac merged commit e963045 into vllm-project:main Aug 5, 2024
66 checks passed
sfc-gh-mkeralapura pushed a commit to sfc-gh-mkeralapura/vllm that referenced this pull request Aug 12, 2024
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug]: Speculative Decoding + FlashInfer + benchmark_serving.py TransferEncodingError ISSUE
2 participants