[Performance] Avoid cuda sync in postprocess of LLM decoding #9011

Open
wants to merge 3 commits into develop

Conversation

Contributor

@lszxb lszxb commented Aug 26, 2024

PR types

Performance optimization

PR changes

Others

Description

Currently, the post-processing stage of each LLM decoding step launches a large number of small, fragmented operators, which incurs significant kernel launch overhead on the CPU side and leads to low GPU utilization. Normally, the asynchronous execution of CUDA kernels lets the CPU launch kernels ahead of time and hide this overhead, but the current implementation contains many operations in the post-processing stage that trigger CUDA synchronization, so the overhead cannot be hidden and performance degrades. As shown in the figure below, the red box contains many fragmented kernels and GPU utilization is low:
image
This PR rewrites the post-processing code to avoid these CUDA synchronization operations. In tests on a Llama model, it brings roughly a 3% performance improvement. As shown in the figure below, the fragmented kernels in the red box now take much less time because the CPU-side kernel launch overhead is hidden:
image

The synchronization-causing operations that this PR identifies and fixes mainly include:

  • paddle.full calls in a particular form
    Calling the full-family operators in the following form triggers a blocking cudaMemcpy:
    paddle.full(shape=[probs.shape[0], 1], fill_value=top_p, dtype=probs.dtype)
    paddle.ones([attention_mask.shape[0], 1], dtype="int64")
    
    There are two main occurrences in the current code: before top_p_sampling, and in the attention_mask update inside update_model_kwargs_for_generation.
    • Rewriting these calls as paddle.empty plus an in-place assignment fixes the problem (see the first sketch after this list), but because update_model_kwargs_for_generation may be overridden by individual models, the usage cannot simply be changed everywhere at once.
    • The root cause is that constant_folding_pass mistakenly treats the "1" in [probs.shape[0], 1] as a parameter rather than a constant. Parameters are stored in GPU memory, so every call triggers a D2H copy. A temporary workaround is to issue a call such as paddle.empty(shape=[1]) beforehand so that the "1" is folded into a constant ahead of time.
      • The shape attribute accepted by paddle.full is an IntArray. To build this IntArray, [probs.shape[0], 1] is combined and then passed through a pd_op.stack operation. The current constant-folding logic checks whether the "1" (or the result of combining it) is used as an attribute input of some operator: if so, it is treated as a constant, otherwise as a parameter. It does not handle the case where the value passes through pd_op.stack before becoming an attribute input, so the "1" is mistaken for a parameter. Calling paddle.empty does not produce a pd_op.stack, so it is handled correctly.
  • An if statement: the fast-path check for top_k = 1. The result of top_k == 1 also lives on the GPU, so evaluating it in Python requires a device-to-host copy, which triggers a synchronization. The check has simply been removed; always executing the else branch produces correct results as well (see the second sketch after this list).
  • paddle.multinomial
    It triggers a pageable cudaMemcpyAsync, which synchronizes. It has been replaced with an equivalent sampling implementation based on the Gumbel-Max trick (see the third sketch after this list).
  • The processing logic in TopKProcess: to_tensor causes a synchronization, so it was changed to clip, and top_k was made a CPU tensor (see the fourth sketch after this list).
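A minimal sketch of the paddle.empty rewrite described in the first bullet, assuming an in-place fill_ as the assignment (the exact assignment used in the PR may differ):

```python
import paddle

probs = paddle.rand([8, 32000])
top_p = 0.8

# Before: building the shape [probs.shape[0], 1] for paddle.full goes through
# pd_op.stack, so constant folding misclassifies the literal 1 as a parameter
# stored in GPU memory and every call pays a blocking D2H copy.
# top_p_tensor = paddle.full(shape=[probs.shape[0], 1], fill_value=top_p, dtype=probs.dtype)

# After: allocate with paddle.empty (no pd_op.stack) and assign in place.
top_p_tensor = paddle.empty(shape=[probs.shape[0], 1], dtype=probs.dtype)
top_p_tensor.fill_(top_p)
```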
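A hedged sketch of why the top_k fast path syncs; sampling_fn and the surrounding names are illustrative placeholders, not the PR's code:

```python
import paddle

probs = paddle.nn.functional.softmax(paddle.rand([4, 32000]), axis=-1)
top_k = paddle.to_tensor([1])  # lives on the default (GPU) device

def sampling_fn(p):
    # stand-in for the general sampling path (illustrative only)
    return paddle.argmax(p, axis=-1, keepdim=True)

# `top_k == 1` is a GPU tensor; using it as a Python condition implicitly
# copies it to the host, which synchronizes the stream.
if top_k == 1:
    next_tokens = paddle.argmax(probs, axis=-1, keepdim=True)  # fast path
else:
    next_tokens = sampling_fn(probs)  # general path

# The PR removes the check and always runs the general path, which is also
# correct when top_k == 1 and avoids the device-to-host copy.
next_tokens = sampling_fn(probs)
```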
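A self-contained sketch of Gumbel-Max sampling as an equivalent of paddle.multinomial for drawing one token per row; the helper name and the epsilon guard are illustrative, not the exact PR implementation:

```python
import paddle

def gumbel_max_sample(probs, eps=1e-10):
    # Add Gumbel noise to the log-probabilities and take the argmax. This draws
    # one categorical sample per row entirely on the GPU, so there is no
    # pageable cudaMemcpyAsync as with paddle.multinomial.
    uniform = paddle.uniform(probs.shape, dtype=probs.dtype, min=0.0, max=1.0)
    gumbel = -paddle.log(-paddle.log(uniform + eps) + eps)
    return paddle.argmax(paddle.log(probs + eps) + gumbel, axis=-1, keepdim=True)

probs = paddle.nn.functional.softmax(paddle.rand([4, 32000]), axis=-1)
next_tokens = gumbel_max_sample(probs)  # shape [4, 1], dtype int64
```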
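A hedged sketch of the TopKProcess idea: keep top_k on the CPU and bound it with clip instead of rebuilding it via to_tensor (names are illustrative, not the PR's code):

```python
import paddle

logits = paddle.rand([4, 32000])

# top_k stays a CPU tensor, so bounding it never forces a GPU synchronization.
top_k = paddle.to_tensor([50], dtype="int64", place=paddle.CPUPlace())

# Before: rebuilding top_k with paddle.to_tensor on each call introduced a sync.
# After: clip against the vocabulary size; no new device tensor is created.
top_k = paddle.clip(top_k, max=logits.shape[-1])
```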


paddle-bot bot commented Aug 26, 2024

Thanks for your contribution!


codecov bot commented Aug 26, 2024

Codecov Report

Attention: Patch coverage is 28.57143% with 5 lines in your changes missing coverage. Please review.

Project coverage is 52.78%. Comparing base (81f5ab5) to head (53e229a).
Report is 1 commit behind head on develop.

Files with missing lines                  Patch %   Lines
paddlenlp/generation/utils.py               0.00%   4 Missing ⚠️
paddlenlp/generation/logits_process.py     66.66%   1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9011      +/-   ##
===========================================
- Coverage    52.90%   52.78%   -0.13%     
===========================================
  Files          661      661              
  Lines       107069   106948     -121     
===========================================
- Hits         56650    56453     -197     
- Misses       50419    50495      +76     

☔ View full report in Codecov by Sentry.

@ZHUI ZHUI changed the title Avoid cuda sync in postprocess of LLM decoding [Performance] Avoid cuda sync in postprocess of LLM decoding Sep 13, 2024