
Flash-attention-v2 triton version (adding bias) #2029

Open
tiandiao123 opened this issue Aug 3, 2023 · 10 comments

tiandiao123 commented Aug 3, 2023

Hello friends,
I am wondering whether someone here can help check my modified version of fused_attention_bias (https://gist.github.com/tiandiao123/0b82ea31a5dc5865663c2966e369b05a#file-flash_attention_bias-py-L106). I am trying to start from Triton and the original tutorial example and modify the original flash attention algorithm. In my view, I should only need to add a bias_ptr, load the corresponding bias tile into SRAM, and add it to qk, but the result is not what I expected: after testing, the output does not exactly match my PyTorch implementation. Does anyone have any ideas?
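
For reference, here is a minimal sketch (not the gist's code) of the kind of modification described above, starting from the Triton tutorial's forward kernel: the bias tile for the current (query block, key block) pair is loaded alongside K and V and added to qk before the online-softmax update. It assumes contiguous (batch*heads, seqlen, head_dim) tensors, a (batch*heads, seqlen, seqlen) bias, seqlen divisible by the block sizes, and head_dim in {16, 32, 64, 128}; the names _fwd_bias_kernel and flash_attn_bias are illustrative only.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _fwd_bias_kernel(
    Q, K, V, Bias, Out, sm_scale, seqlen,
    HEAD_DIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # One program per (query block, batch*head).
    start_m = tl.program_id(0)
    off_bh = tl.program_id(1)

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, HEAD_DIM)

    # Load the query tile once; it stays in SRAM for the whole inner loop.
    q_ptrs = Q + off_bh * seqlen * HEAD_DIM + offs_m[:, None] * HEAD_DIM + offs_d[None, :]
    q = tl.load(q_ptrs)

    # Online-softmax running max, running sum, and output accumulator.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

    for start_n in range(0, seqlen, BLOCK_N):
        offs_n_cur = start_n + offs_n
        k_ptrs = K + off_bh * seqlen * HEAD_DIM + offs_n_cur[:, None] * HEAD_DIM + offs_d[None, :]
        v_ptrs = V + off_bh * seqlen * HEAD_DIM + offs_n_cur[:, None] * HEAD_DIM + offs_d[None, :]
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)

        qk = tl.dot(q, tl.trans(k)) * sm_scale
        # The only real change vs. the tutorial: load the matching bias tile
        # and add it to the scaled scores BEFORE the online-softmax update.
        b_ptrs = Bias + off_bh * seqlen * seqlen + offs_m[:, None] * seqlen + offs_n_cur[None, :]
        qk += tl.load(b_ptrs).to(tl.float32)

        m_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.exp(m_i - m_new)
        p = tl.exp(qk - m_new[:, None])
        l_i = l_i * alpha + tl.sum(p, 1)
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
        m_i = m_new

    acc = acc / l_i[:, None]
    o_ptrs = Out + off_bh * seqlen * HEAD_DIM + offs_m[:, None] * HEAD_DIM + offs_d[None, :]
    tl.store(o_ptrs, acc.to(Out.dtype.element_ty))


def flash_attn_bias(q, k, v, bias, sm_scale):
    # q, k, v: (batch*heads, seqlen, head_dim) contiguous; bias: (batch*heads, seqlen, seqlen).
    bh, seqlen, head_dim = q.shape
    out = torch.empty_like(q)
    grid = (triton.cdiv(seqlen, 64), bh)
    _fwd_bias_kernel[grid](q, k, v, bias, out, sm_scale, seqlen,
                           HEAD_DIM=head_dim, BLOCK_M=64, BLOCK_N=64)
    return out
```

Note that even a correct kernel will not match a plain PyTorch reference bit-for-bit: the blockwise computation changes the floating-point accumulation order, so compare with torch.allclose and a tolerance appropriate for the dtype (roughly atol=1e-2 for fp16) rather than expecting exact equality.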

tiandiao123 (Author) commented

probably @ptillet can take a look at it?

chaaland commented Aug 9, 2023

Have you looked at this implementation? I haven't checked its correctness myself, but the comments about bugs in some head_dim regimes indicate it's been tested.

tiandiao123 (Author) commented

> Have you looked at this implementation? I haven't checked its correctness myself, but the comments about bugs in some head_dim regimes indicate it's been tested.

let me check!

chaaland commented

Looks like this one from mosaic uses the same implementation

shiqingzhangCSU commented

Is there any progress? I also want to implement flash2+bias.

shiqingzhangCSU commented

@tiandiao123 Hi! Is there any update? I currently want to implement this version too.

chaaland commented

@shiqingzhangCSU why not just use the one from mosaic?

tiandiao123 (Author) commented

> @tiandiao123 Hi! Is there any update? I currently want to implement this version too.

I saw lightllm has some updates: https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L10

juntang-zhuang commented

Marking this; it's an important feature.

alexzhang13 commented

I've written a version of this here: https://github.com/alexzhang13/flashattention2-custom-mask. To add arbitrary attention biases, you just need to remove the masking logic (torch.where).

If it's still needed, I can write out this explicit functionality as well.
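
For illustration, here is a minimal plain-PyTorch sketch of the point above: a boolean mask is just the special case of an additive bias that is 0 where attention is allowed and -inf where it is not, so a kernel that accepts an arbitrary bias no longer needs a separate torch.where/tl.where masking step. The name attention_with_bias_ref is illustrative only.

```python
import math
import torch


def attention_with_bias_ref(q, k, v, bias):
    # Naive reference: softmax(q @ k^T / sqrt(d) + bias) @ v.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return torch.softmax(scores + bias, dim=-1) @ v


# A causal mask expressed as an additive bias: 0 where allowed, -inf where masked.
S = 128
causal = torch.tril(torch.ones(S, S, dtype=torch.bool))
causal_bias = torch.zeros(S, S).masked_fill(~causal, float("-inf"))
```

A reference like this is also handy for checking a fused kernel's output with torch.allclose and a loose tolerance.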
