
Support flash attention 2 with KV's sequence length longer than Q's #2033

Merged: 2 commits merged into triton-lang:main on Aug 8, 2023

Conversation

@BoxiangW (Contributor) commented Aug 4, 2023

Implemented this case both with and without a causal mask.
With a causal mask, my implementation looks like:
111000
111100
111110
111111
where only the upper-right triangular part is masked out.
I added `P_SEQ` to denote the extra sequence length of KV.
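
For illustration, here is a minimal NumPy sketch of that masking rule, separate from the actual Triton kernel in this PR; the shapes (seqlen_q = 4, seqlen_k = 6) and variable names are assumptions chosen only to reproduce the pattern above.

```python
# Minimal sketch (not the Triton kernel) of the causal mask when KV is
# longer than Q. P_SEQ = seqlen_k - seqlen_q is the extra KV length.
import numpy as np

seqlen_q, seqlen_k = 4, 6          # illustrative shapes
P_SEQ = seqlen_k - seqlen_q        # extra KV tokens every query may attend to

q_idx = np.arange(seqlen_q)[:, None]   # query positions, column vector
k_idx = np.arange(seqlen_k)[None, :]   # key positions, row vector

# A query at position i may attend to keys at positions j <= i + P_SEQ,
# so only the upper-right triangle beyond that offset is masked out.
mask = k_idx <= q_idx + P_SEQ

print(mask.astype(int))
# [[1 1 1 0 0 0]
#  [1 1 1 1 0 0]
#  [1 1 1 1 1 0]
#  [1 1 1 1 1 1]]
```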

@BoxiangW requested a review from @ptillet as a code owner on August 4, 2023 03:51
@BoxiangW (Contributor, Author) commented Aug 4, 2023

Link to issue: #2025

@ptillet ptillet merged commit f21a053 into triton-lang:main Aug 8, 2023
3 of 4 checks passed
@janEbert (Contributor) commented Aug 9, 2023

Should this also be added to python/triton/ops/flash_attention.py?

@BoxiangW (Contributor, Author) commented Aug 9, 2023

I will try to add it to python/triton/ops/flash_attention.py as well.

pingzhuu pushed a commit to siliconflow/triton that referenced this pull request on Apr 2, 2024: Support flash attention 2 with KV's sequence length longer than Q's (triton-lang#2033)

Co-authored-by: Philippe Tillet <[email protected]>