support attentions in AlphaFold2 #57
base: main
Conversation
* add support for attn mask
* add mask operation
* add mask operation
* add mask operation
* add interface
* add mask support
* add mask supprt
* fix up
* add bias
* add template
* add test
* clean
* clean code
* add mask load
* add mask test
* fix forward bugs
* add test
* add mask in backward
* add test case
* add bias
* add mask
* add bias test
* fix test case
* add without mask test
* add kernel test
* add ds save
* fix interface
* add test
* fix dbias
* add bias support
* add mask shape
* add test
* add support
* fix bf16 and mask shape
* fix mask head=1 shape
* add dump
* to fix len 512
* add test
* fix seqlen greater than 256
* fix bias seqlen
* add constexpr
* add const expr for bwd
* add benchmark
* add test tools
* add script
* add cross attention
* add cross attn
* fix bugs
* remove test tools
* clean fmha_api.cpp
* clean fmha_dgrad_fp16_kernel_loop.sm80.cu
* clean fmha_dgrad_kernel_1xN_loop.h
* clean fmha_fprop_fp16_kernel.sm80.cu
* clean fmha_fprop_kernel_1xN.h
* clean gmem_tile.h
* clean softmax.h
* restore test_flash_attn.py
* clean gmem_tile.h
* fix fmha_fprop_kernel_1xN.h
* fix fmha_dgrad_kernel_1xN_loop.h
* rename has_attn to has_attn_mask, has_bias to has_attn_bias
* fix fmha_fprop_kernel_1xN.h
* rename has_attn to has_attn_mask, has_bias to has_attn_bias
* remove useless benchmark code
* add declaration
* remove useless comments
* remove useless comments
* add timeout
* add default timeout for build wheel
* remove timeout
* reduce build worker for workflow oom
Currently, we have implemented the following cases for the attention bias/mask.
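The supported shapes, as far as they can be inferred from the shape checks quoted later in this thread, look roughly like the sketch below (treat the exact shapes as assumptions):

```python
import torch

# Assumed layout, inferred from the TORCH_CHECKs discussed later in this PR:
# mask/bias are additive fp16 tensors of shape [bsz, nh_or_1, qlen_or_1, k_len].
bsz, nheads, q_len, k_len = 2, 8, 128, 128

# Key mask broadcast over heads and queries (e.g. padding): 0 keeps a key, -inf drops it.
attn_mask = torch.zeros(bsz, 1, 1, k_len, dtype=torch.float16, device="cuda")
attn_mask[:, :, :, 100:] = float("-inf")

# Per-head pair bias over the full [q_len, k_len] grid.
attn_bias = torch.randn(bsz, nheads, q_len, k_len, dtype=torch.float16, device="cuda")
```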
Thanks so much for the great work, and congrats on the speedup on Uni-Fold! I'll have more time this weekend to review carefully.
Great, any suggestions are welcome. We still have some things that need refining to make it more broadly applicable.
This does not work if the mask or bias has an odd sequence length.
@guolinke @robotcator Do we need both mask & bias, or would a single bias suffice? I think that could simplify the code & reduce compilation time.
From the shape given, my understanding is that the mask is a key-padding mask. Does it change across different layers for the same batch?
Thank you for your advice.
Thanks for the suggestion @tridao. The flattened non-padding input layout is not trivial to use in AlphaFold2.
I see, thanks for explaining, this is very helpful. If this sounds reasonable, I'll take a stab at implementing the seqlen_k masking and then rebase and merge the bias part from this PR?
Another way to phrase this question: is the mask for each sequence always of the form [0, 0, ..., 0, -inf, -inf, ...]? Or could it have the form [0, -inf, 0, ..., -inf, 0]?
@tridao Hi Tri, sorry for the late reply. Using the "key padding mask" style would be a really good way to reduce code complexity and compilation time, but we checked and the masked keys are not always at the end of the sequence, so we chose the additive attention-mask approach instead.
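To make the distinction concrete, here is a minimal PyTorch sketch (standard attention math, not this PR's kernel): with an additive mask, the -inf entries can sit anywhere along the key axis, which a "first n valid, rest padded" key-padding mask cannot express.

```python
import torch

q_len, k_len = 4, 6
scores = torch.randn(q_len, k_len)

# Masked key positions scattered through the sequence, e.g. keys 1 and 4.
mask = torch.zeros(k_len)
mask[[1, 4]] = float("-inf")

probs = torch.softmax(scores + mask, dim=-1)  # masked columns receive zero weight
assert probs[:, [1, 4]].sum() == 0
```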
@robotcator I encounter gradient overflow when attn_mask is not None or attn_bias is not None. Could you give me some advice?
@tridao Any update on merging this, or the part to support arbitrary masks and biases? |
I just haven't had time to review and merge it (it's a pretty big change). Still trying to figure out a good way to support both mask and bias without increasing compilation time by 4x. |
Do you mean overflow or nan? And can you provide some shapes of inputs? |
The model is trained with FP16. With FP16 training the loss may explode, and we progressively lower the dynamic loss scale until it reaches the minimum value. The overflow appears when attn_mask or attn_bias is not None. The code snippet is shown below (following https://github.com/dptech-corp/flash-attention/blob/main/flash_attn/attention.py):
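A rough sketch of the kind of call being described (the import path, module name, keyword argument names, and return format below are assumptions based on this discussion; the actual interface is in flash_attn/attention.py in the linked fork):

```python
import torch
# Assumed interface: the fork's attention module taking additive attn_mask / attn_bias.
from flash_attn.attention import FlashAttention

bsz, seqlen, nheads, headdim = 1, 256, 8, 32
qkv = torch.randn(bsz, seqlen, 3, nheads, headdim,
                  dtype=torch.float16, device="cuda", requires_grad=True)

attn_mask = torch.zeros(bsz, nheads, 1, seqlen, dtype=torch.float16, device="cuda")
attn_bias = torch.randn(bsz, nheads, seqlen, seqlen,
                        dtype=torch.float16, device="cuda", requires_grad=True)

attn = FlashAttention(attention_dropout=0.0)
out = attn(qkv, attn_mask=attn_mask, attn_bias=attn_bias)[0]  # return format assumed

# With dynamic loss scaling, the scaled fp16 loss is backpropagated; an inf/nan in the
# grads makes the scaler skip the step and shrink the scale, as described above.
out.float().sum().backward()
```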
It seems this is not trivial to figure out. Here are some ideas from my side: 1) check whether the half precision overflows due to its limited representation range; 2) the attention bias & mask handling is not as flexible as the PyTorch version. The broadcast mechanism is very flexible in PyTorch, but it takes more effort to implement when all operations are fused into one kernel. We implemented a limited set of shapes to fit our model, so it is not generalized to all models; please check the supported list carefully.
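For point 1), a quick generic diagnostic (not code from this PR) to see whether a tensor or its gradient is leaving the fp16 range:

```python
import torch

def report_fp16_range(name: str, t: torch.Tensor) -> None:
    """Print a warning when a tensor contains inf/nan or is near the fp16 limit (~65504)."""
    fp16_max = torch.finfo(torch.float16).max
    t32 = t.detach().float()
    if not torch.isfinite(t32).all():
        print(f"{name}: contains inf/nan")
    elif t32.abs().max().item() > 0.5 * fp16_max:
        print(f"{name}: max |value| = {t32.abs().max().item():.1f}, close to fp16 max {fp16_max}")

# Usage after backward(), e.g.:
# report_fp16_range("attn_bias", attn_bias)
# report_fp16_range("attn_bias.grad", attn_bias.grad)
```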
Thank you for the reply. |
* add odd length support
* add mask in attn_mask & attn_bias
* rm useless files
* move if to the outer loop
* remove comments

Co-authored-by: xhj <[email protected]>
Hi, thanks everyone for bringing up this enhancement! Is this PR a way to support custom attention masks? Is this the best workaround so far, given that it is not officially supported yet?
For the padding mask, I think the official repo already supports it. For custom attention masks, we also support some shapes, but not all.
mask_head_mod_size = mask_sizes[1];
mask_seq_mod_size = mask_sizes[2];
TORCH_CHECK(mask_sizes[1] == 1 || mask_sizes[1] == num_heads);
TORCH_CHECK(mask_sizes[2] == 1 || mask_sizes[2] == max_seqlen_q_);
The first comment in the conversation indicates that the mask has to be [bsz, nh, 1, k_len], but the code here suggests that it supports the full [bsz, nh, q_len, k_len], with broadcast supported for nh and q_len.
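A Python mirror of those checks (illustrative only): dim 1 may be 1 or num_heads and dim 2 may be 1 or the query length, so both the [bsz, nh, 1, k_len] case and the full [bsz, nh, q_len, k_len] case are accepted.

```python
def mask_shape_ok(mask_sizes, num_heads, max_seqlen_q):
    # Mirrors the two TORCH_CHECKs above: the head and query dims are either
    # broadcast (size 1) or full size.
    return (mask_sizes[1] in (1, num_heads)) and (mask_sizes[2] in (1, max_seqlen_q))

assert mask_shape_ok([2, 8, 1, 128], num_heads=8, max_seqlen_q=128)        # key-padding style
assert mask_shape_ok([2, 8, 128, 128], num_heads=8, max_seqlen_q=128)      # full mask
assert not mask_shape_ok([2, 4, 128, 128], num_heads=8, max_seqlen_q=128)  # partial heads rejected
```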
@robotcator I have a question about
Hello, what's the status of this PR? Is the code in a usable state? I didn't quite get it from the above discussion. Thanks for your work, awesome job!
I don't know whether it's too late to reply; actually, the gradient of attn_bias will be computed automatically.
Guys, let's face it. |
For anyone still looking for this see: https://pytorch.org/blog/flexattention/ |
We added support for an (additive) attention_mask and an (additive) attention_bias, so that flash-attention can be used in the Evoformer in AlphaFold2. We benchmarked it in Uni-Fold, and it achieved a further ~20% speed-up.
Comments and suggestions are very welcome!
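For context, the unfused PyTorch computation that this PR folds into the flash-attention kernel looks roughly like the sketch below (a reference formulation, not the PR's code): the additive mask and the pair-representation bias are both added to the attention logits before the softmax.

```python
import torch

def attention_with_mask_and_bias(q, k, v, attn_mask=None, attn_bias=None):
    """Reference (unfused) attention with additive mask and bias, Evoformer-style.

    q, k, v:   [batch, heads, seqlen, head_dim]
    attn_mask: additive, e.g. [batch, heads, 1, seqlen_k] (0 keeps a key, -inf drops it)
    attn_bias: additive, e.g. [batch, heads, seqlen_q, seqlen_k] (pair bias)
    """
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    if attn_bias is not None:
        scores = scores + attn_bias
    if attn_mask is not None:
        scores = scores + attn_mask
    return torch.softmax(scores, dim=-1) @ v
```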
Some benchmark results:
Training GPU hours:
Inference time and memory cost (one evoformer layer, without chunking):