Flash-attention-v2 triton version (adding bias) #2029
probably @ptillet can take a look at it?
Have you looked at this implementation? I haven't checked its correctness myself, but the comments about bugs in some head_dim regimes indicate it's been tested.
let me check!
Looks like this one from mosaic uses the same implementation.
Is there any progress? I also want to implement flash2+bias.
@tiandiao123 Hi, is there any update? I currently want to implement this version too.
@shiqingzhangCSU why not just use the one from mosaic?
I saw lightllm has some updates: https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L10
Mark, important feature.
I've written a version of this here: https://github.com/alexzhang13/flashattention2-custom-mask. To add arbitrary attention biases you just need to remove the masking logic (torch.where). If it's still needed, I can write out this explicit functionality as well.
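To make the mask-vs-bias point concrete: a boolean mask applied with torch.where is just an additive bias whose masked entries are very negative, so a kernel that adds an arbitrary bias tile to qk subsumes masking. A minimal PyTorch reference sketch (the function name and shape conventions are illustrative, not taken from the linked repo), which also serves as a ground truth for testing a Triton kernel:

```python
import torch

def attn_reference(q, k, v, bias=None, mask=None):
    """Naive attention; useful as a ground truth for a fused Triton kernel."""
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("...md,...nd->...mn", q, k) * scale
    if bias is not None:
        scores = scores + bias  # arbitrary additive bias, e.g. ALiBi / relative pos.
    if mask is not None:
        # The torch.where masking step: equivalent to a bias that is
        # finfo.min wherever mask is False.
        neg = torch.full_like(scores, torch.finfo(scores.dtype).min)
        scores = torch.where(mask, scores, neg)
    return torch.softmax(scores, dim=-1) @ v
```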
Hello friends:
I am wondering whether someone here can help check my modified version of fused_attention_bias (https://gist.github.com/tiandiao123/0b82ea31a5dc5865663c2966e369b05a#file-flash_attention_bias-py-L106). I am trying to modify the original flash attention algorithm starting from the Triton tutorial example. My thinking is that I only need to add a bias_ptr, load the corresponding bias tile into SRAM, and add it to qk, but the result is not what I expected: after testing, the output does not exactly match my PyTorch implementation. Does anyone have ideas?
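For reference, below is a minimal sketch of the modification being described: the stock tutorial inner loop with one extra tl.load for a bias tile that is added to qk before the online-softmax update. All names, strides, and the flat single-(batch, head) layout are illustrative simplifications, not the gist's actual code; boundary masking and multi-head indexing are omitted. One common source of the kind of mismatch described above: the Triton tutorial kernel folds log2(e) into qk_scale and exponentiates with tl.math.exp2, so a bias added after that rescaling must itself be multiplied by 1.44269504 (or be added to the raw scores and exponentiated with plain tl.exp, as done here).

```python
import triton
import triton.language as tl

@triton.jit
def _attn_bias_fwd(
    Q, K, V, Bias, Out,              # pointers into one (batch, head) slice each
    stride_qm, stride_qd,
    stride_kn, stride_kd,
    stride_vn, stride_vd,
    stride_bm, stride_bn,
    stride_om, stride_od,
    seqlen, sm_scale,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr,
):
    # Each program owns one block of query rows (assumes seqlen % BLOCK == 0).
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_D)

    # Query block stays resident for the whole K/V sweep.
    q = tl.load(Q + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd)

    # Running max / normalizer / accumulator for the online softmax (FA-2 style).
    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)

    for start_n in range(0, seqlen, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        k = tl.load(K + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd)
        v = tl.load(V + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd)

        qk = tl.dot(q, tl.trans(k)) * sm_scale
        # The one extra step vs. the stock tutorial: load the matching
        # (BLOCK_M, BLOCK_N) bias tile and add it BEFORE the softmax update.
        b = tl.load(Bias + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn)
        qk += b.to(tl.float32)

        m_new = tl.maximum(m_i, tl.max(qk, 1))
        p = tl.exp(qk - m_new[:, None])   # plain exp: no log2(e) factor needed
        alpha = tl.exp(m_i - m_new)       # rescale old stats to the new max
        l_i = l_i * alpha + tl.sum(p, 1)
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
        m_i = m_new

    acc = acc / l_i[:, None]
    tl.store(Out + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od,
             acc.to(Out.dtype.element_ty))
```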