support attentions in AlphaFold2 #57

Open
wants to merge 3 commits into base: main

Conversation


@guolinke commented Oct 13, 2022

We added support for an (additive) attention_mask and an (additive) attention_bias, so that FlashAttention can be used in the Evoformer of AlphaFold2. We benchmarked it in Uni-Fold, and it achieved a further ~20% speed-up.

Comments and suggestions are very welcome!

Some benchmark results:

Training GPU hours:
[image: training GPU hours comparison]

Inference time and memory cost (one evoformer layer, without chunking):
[image: inference time and memory benchmark]

* add support for attn mask

* add mask operation

* add mask operation

* add mask operation

* add interface

* add mask support

* add mask support

* fix up

* add bias

* add template

* add test

* clean

* clean code

* add mask load

* add mask test

* fix forward bugs

* add test

* add mask in backward

* add test case

* add bias

* add mask

* add bias test

* fix test case

* add without mask test

* add kernel test

* add ds save

* fix interface

* add test

* fix dbias

* add bias support

* add mask shape

* add test

* add support

* fix bf16 and mask shape

* fix mask head=1 shape

* add dump

* to fix len 512

* add test

* fix seqlen greater than 256

* fix bias seqlen

* add constexpr

* add const expr for bwd

* add benchmark

* add test tools

* add script

* add cross attention

* add cross attn

* fix bugs

* remove test tools

* clean fmha_api.cpp

* clean fmha_dgrad_fp16_kernel_loop.sm80.cu

* clean fmha_dgrad_kernel_1xN_loop.h

* clean fmha_fprop_fp16_kernel.sm80.cu

* clean fmha_fprop_kernel_1xN.h

* clean gmem_tile.h

* clean softmax.h

* restore test_flash_attn.py

* clean gmem_tile.h

* fix fmha_fprop_kernel_1xN.h

* fix fmha_dgrad_kernel_1xN_loop.h

* rename has_attn to has_attn_mask, has_bias to has_attn_bias

* fix fmha_fprop_kernel_1xN.h

* rename has_attn to has_attn_mask, has_bias to has_attn_bias

* remove useless benchmark code

* add declaration

* remove useless comments

* remove useless comments

* add timeout

* add default timeout for build wheel

* remove timeout

* reduce build worker for workflow oom
@guolinke mentioned this pull request Oct 13, 2022
@robotcator (Contributor)

Currently, we implemented the following cases for the attention bias/mask (a shape sketch follows below).

Supported q/k/v shapes:
q: [total_size * head, seq_q, head_dim]
k: [total_size * head, seq_k, head_dim]
v: [total_size * head, seq_k, head_dim]

Attention mask: [total_size, head, seq_q, seq_k]
1. total_size must be the same as q's total_size
2. head must be 1 or the same as q's head
3. seq_q must be 1
4. seq_k must be the same as k's seq_k

Attention bias: [total_size, head, seq_q, seq_k]
1. total_size must be 1
2. head must be the same as q's head
3. seq_q must be the same as q's seq_q
4. seq_k must be the same as k's seq_k
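To make the constraints concrete, here is a small shape sketch (illustrative only; the sizes are made up, and the tensors are simply what would be handed to the patched flash-attn interface):

import torch

# Hypothetical sizes: batch 2, 8 heads, 128 queries, 256 keys, head_dim 64.
total_size, head, seq_q, seq_k, head_dim = 2, 8, 128, 256, 64

# q/k/v are packed as [total_size * head, seq, head_dim].
q = torch.randn(total_size * head, seq_q, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(total_size * head, seq_k, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(total_size * head, seq_k, head_dim, dtype=torch.float16, device="cuda")

# Additive attention mask: head may be 1 (broadcast over heads), seq_q must be 1.
attn_mask = torch.zeros(total_size, 1, 1, seq_k, dtype=torch.float16, device="cuda")
attn_mask[:, :, :, 200:] = float("-inf")  # e.g. mask out padded keys

# Additive attention bias: total_size must be 1, full seq_q x seq_k per head.
attn_bias = torch.randn(1, head, seq_q, seq_k, dtype=torch.float16, device="cuda")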

@tridao (Contributor)

tridao commented Oct 19, 2022

Thanks so much for the great work, and congrats on the speedup on Uni-Fold!

I'll have more time this weekend to review carefully.

@robotcator (Contributor)

Thanks so much for the great work, and congrats on the speedup on Uni-Fold!

I'll have more time this weekend to review carefully.

Great, any suggestions are welcome. We still have a few things to refine to make it more broadly applicable:

  1. Fixing the interface incompatibility in flash_attn_interface.py
  2. Adding our unit tests for the mask and bias interface.
  3. Adding support for an odd mask/bias length in the last dimension.

@tridao force-pushed the main branch 3 times, most recently from 30ddfcc to 50ca234 on October 24, 2022 00:26
@reymondzzzz

This does not work if the mask or bias has an odd sequence length. CUDA error (/tmp/pip-req-build-k5fpgkes/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu:140): misaligned address

@tridao (Contributor)

tridao commented Nov 6, 2022

@guolinke @robotcator Do we need both mask & bias, or would a single bias suffice? I think that could simplify the code & reduce compilation time.

Attention mask: [total_size, head, seq_q, seq_k]

  1. total_size must be the same as q's total_size
  2. head must be 1 or the same as q's head
  3. seq_q must be 1
  4. seq_k must be the same as k's seq_k

From the shape given, my understanding is that the mask is a key-padding mask. Does it change across different layers for the same batch?
If the key-padding mask doesn't change across layers, then the most performant way is to remove the padding before the first layer (we have a function unpad_input), run through all the layers, then optionally add the padding tokens back.
Is my understanding correct?
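For reference, a minimal sketch of this unpadding path using the bert_padding helpers shipped with flash-attn (the exact return values have varied slightly across versions, so treat this as illustrative):

import torch
from flash_attn.bert_padding import unpad_input, pad_input

# hidden_states: (batch, seqlen, dim); key_padding_mask: (batch, seqlen), True for real tokens.
batch, seqlen, dim = 4, 512, 768
hidden_states = torch.randn(batch, seqlen, dim, dtype=torch.float16, device="cuda")
key_padding_mask = torch.ones(batch, seqlen, dtype=torch.bool, device="cuda")
key_padding_mask[:, 400:] = False  # pretend the tail of every sequence is padding

# Remove padding once, before the first layer.
x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, key_padding_mask)[:4]

# ... run all attention layers on x_unpad, passing (cu_seqlens, max_seqlen) to the kernels ...

# Optionally add the padding tokens back after the last layer.
x_pad = pad_input(x_unpad, indices, batch, seqlen)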

@robotcator (Contributor)

This does not work if the mask or bias has an odd sequence length. CUDA error (/tmp/pip-req-build-k5fpgkes/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu:140): misaligned address

Thank you for reporting this. Supporting an odd mask/bias length in the last dimension is on our to-do list.

@guolinke (Author)

guolinke commented Nov 7, 2022

@guolinke @robotcator Do we need both mask & bias, or would a single bias suffice? I think that could simplify the code & reduce compilation time.

Attention mask: [total_size, head, seq_q, seq_k]

  1. total_size must be the same as q's total_size
  2. head must be 1 or the same as q's head
  3. seq_q must be 1
  4. seq_k must be the same as k's seq_k

From the shape given, my understanding is that the mask is a key-padding mask. Does it change across different layers for the same batch? If the key-padding mask doesn't change across layers, then the most performant way is to remove the padding before the first layer (we have a function unpad_input), run through all the layers, then optionally add the padding tokens back. Is my understanding correct?

Thanks for the suggestion, @tridao. Flattening out the padded input is not trivial in AlphaFold2:

  1. There are two representations (token-level and pair-level) and four kinds of attention in the Evoformer, each with a different mask/bias.
  2. The two representations communicate at every Evoformer layer, and the padded form is more convenient for that computation.

@tridao (Contributor)

tridao commented Nov 7, 2022

Flattening out the padded input is not trivial in AlphaFold2.

I see, thanks for explaining, this is very helpful.
How about we pass in a tensor (of type int) with the sequence length of the keys for each batch element? That might be faster (we read one int instead of a whole mask vector) and simpler (it reduces code complexity and compilation time).
Would this work for the alphafold2 use case?

If this sounds reasonable I'll take a stab at implementing the seqlen_k masking and then rebase and merge the bias part from this PR?
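A rough sketch of what that could look like on the caller side (entirely hypothetical: `seqlens_k` is not an existing argument of flash_attn_unpadded_func, it is the interface being proposed here):

import torch

# Boolean key-padding mask, (batch, seq_k), True for real tokens.
key_padding_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                                 [1, 1, 1, 1, 0, 0]], dtype=torch.bool, device="cuda")

# One int per batch element: the kernel would ignore keys with index >= seqlens_k[b].
seqlens_k = key_padding_mask.sum(dim=-1).to(torch.int32)  # tensor([3, 4], dtype=torch.int32)

# Hypothetical call, reading one int per batch instead of a full additive mask:
# out = flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k,
#                                max_seqlen_q, max_seqlen_k,
#                                seqlens_k=seqlens_k, attn_bias=attn_bias, ...)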

@tridao (Contributor)

tridao commented Nov 7, 2022

Another way to phrase this question: is the mask for each sequence always of the form [0, 0, ..., 0, -inf, -inf ...]? Or could they have the form [0, -inf, 0, ..., -inf, 0]?
That is, are the masked keys always at the end of the sequence?

@robotcator (Contributor)

Another way to phrase this question: is the mask for each sequence always of the form [0, 0, ..., 0, -inf, -inf ...]? Or could they have the form [0, -inf, 0, ..., -inf, 0]? That is, are the masked keys always at the end of the sequence?

@tridao Hi, sorry for the late reply. Using the key-padding-mask style would indeed be a good way to reduce code complexity and compilation time, but we checked and the masked keys are not always at the end of the sequence.

One case is the gen_msa_attn_mask function here, which generates two types of masks, i.e. row_mask and col_mask.

The row_mask is generated from the original msa_mask tensor and the col_mask from the transpose of the msa_mask tensor, so the col_mask's masked keys are not at the end of the sequence. A minimal example is as follows.

# the original `msa_mask` tensor.
tensor([[0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf]])

# the transpose of the `msa_mask` tensor.
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [-inf, 0.],
        [-inf, -inf],
        [-inf, -inf]])

In another case, the msa_mask's masked keys are not always padding at the end of the sequence; they can appear at any position in the sequence.

So we chose to use the attention mask rather than the key-padding-mask style. If anything is unclear, please feel free to contact us. We also suffer from the compilation-time problem and hope we can find a way to tackle it.
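For illustration, a rough sketch of how such row/column masks can be derived from a boolean MSA mask (this is not the actual gen_msa_attn_mask implementation, just the idea):

import torch

def to_additive(mask_bool):
    # Boolean mask (True = keep) -> additive mask (0 for keep, -inf for masked).
    out = torch.zeros(mask_bool.shape, dtype=torch.float32)
    out.masked_fill_(~mask_bool, float("-inf"))
    return out

# Boolean MSA mask: (num_rows, num_cols), True where a position is valid.
msa_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                         [1, 1, 1, 1, 0, 0]], dtype=torch.bool)

row_mask = to_additive(msa_mask)      # -inf only at the end of each row
col_mask = to_additive(msa_mask.t())  # after transposing, -inf is no longer always at the end
print(col_mask)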

@logicwong

@robotcator I encounter gradient overflow when attn_mask or attn_bias is not None. Could you give me some advice?

@rahul003

rahul003 commented Jan 4, 2023

@tridao Any update on merging this, or the part to support arbitrary masks and biases?

@tridao (Contributor)

tridao commented Jan 4, 2023

I just haven't had time to review and merge it (it's a pretty big change). I'm still trying to figure out a good way to support both the mask and the bias without increasing compilation time by 4x.

@robotcator (Contributor)

@robotcator I encounter gradient overflow when attn_mask or attn_bias is not None. Could you give me some advice?

Do you mean overflow or NaN? And can you provide the shapes of your inputs?

@logicwong

@robotcator I encounter gradient overflow when attn_mask or attn_bias is not None. Could you give me some advice?

Do you mean overflow or NaN? And can you provide the shapes of your inputs?

The model is trained with FP16. With FP16 training, when the loss explodes we progressively lower the dynamic loss scale until it reaches the minimum value. If attn_bias is not None, the loss scale quickly reaches the minimum value right at the beginning, like this:
[screenshot: dynamic loss-scale log]

The code snippet is shown below (following https://github.com/dptech-corp/flash-attention/blob/main/flash_attn/attention.py):

def attention(q, k, v, attn_bias, seq_len):
    # q (bsz * seq_len, num_heads, dim) = (128 * seq_len, 12, 64)
    # k (bsz * seq_len, num_heads, dim) = (128 * seq_len, 12, 64)
    # v (bsz * seq_len, num_heads, dim) = (128 * seq_len, 12, 64)
    # attn_bias (bsz, num_heads, seq_len, seq_len) = (128, 12, seq_len, seq_len)
    bsz = q.shape[0] // seq_len  # batch size (128 here)

    cu_seqlens = torch.arange(
        0, (bsz + 1) * seq_len, step=seq_len, dtype=torch.int32, device=q.device
    )
    attn = flash_attn_unpadded_func(
        q, k, v, cu_seqlens, cu_seqlens, seq_len, seq_len,
        attn_mask=None, attn_bias=attn_bias,
        dropout_p=0.0,
        softmax_scale=1.0, causal=False
    )
    return attn

@robotcator (Contributor)

@robotcator I encounter gradient overflow when attn_mask or attn_bias is not None. Could you give me some advice?

Do you mean overflow or NaN? And can you provide the shapes of your inputs?

The model is trained with FP16. With FP16 training, when the loss explodes we progressively lower the dynamic loss scale until it reaches the minimum value. If attn_bias is not None, the loss scale quickly reaches the minimum value right at the beginning, like this: [screenshot: dynamic loss-scale log]

The code snippet is shown below (following https://github.com/dptech-corp/flash-attention/blob/main/flash_attn/attention.py):

def attention(q, k, v, attn_bias, seq_len):
    # q (bsz * seq_len, num_heads, dim) = (128 * seq_len, 12, 64)
    # k (bsz * seq_len, num_heads, dim) = (128 * seq_len, 12, 64)
    # v (bsz * seq_len, num_heads, dim) = (128 * seq_len, 12, 64)
    # attn_bias (bsz, num_heads, seq_len, seq_len) = (128, 12, seq_len, seq_len)
    bsz = q.shape[0] // seq_len  # batch size (128 here)

    cu_seqlens = torch.arange(
        0, (bsz + 1) * seq_len, step=seq_len, dtype=torch.int32, device=q.device
    )
    attn = flash_attn_unpadded_func(
        q, k, v, cu_seqlens, cu_seqlens, seq_len, seq_len,
        attn_mask=None, attn_bias=attn_bias,
        dropout_p=0.0,
        softmax_scale=1.0, causal=False
    )
    return attn

It's not trivial to figure out. Here are some ideas from my side: 1) check whether the half precision overflows because of its limited representation range; 2) the attention bias & mask handling here is not as general as the PyTorch version. Broadcasting is very flexible in PyTorch, but it takes more effort to implement when all operations are fused into one kernel. We implemented a limited set of shapes to fit our model, so it does not generalize to all models. Please check your shapes carefully against the supported list.
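Along those lines, a rough sanity-check sketch one could run before calling the patched kernel (the shape constraint is taken from the supported list earlier in this thread; later commits may have relaxed it):

import torch

def check_attn_bias(q, attn_bias, num_heads, seq_q, seq_k):
    # q is packed as (bsz * num_heads, seq_q, head_dim)
    assert q.shape[0] % num_heads == 0 and q.shape[1] == seq_q

    # 1) FP16 range: NaNs or very large magnitudes in the bias will sink the loss scale.
    assert not torch.isnan(attn_bias).any(), "NaN in attn_bias"
    finite = attn_bias[torch.isfinite(attn_bias)]
    assert finite.abs().max() < 6.0e4, "finite bias values too large for fp16 (max ~65504)"

    # 2) Shape constraint from the supported list: bias is [1, head, seq_q, seq_k].
    assert attn_bias.shape == (1, num_heads, seq_q, seq_k), \
        f"unsupported attn_bias shape {tuple(attn_bias.shape)}"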

@logicwong

@robotcator I encounter gradient overflow when attn_mask or attn_bias is not None. Could you give me some advice?

Do you mean overflow or NaN? And can you provide the shapes of your inputs?

The model is trained with FP16. With FP16 training, when the loss explodes we progressively lower the dynamic loss scale until it reaches the minimum value. If attn_bias is not None, the loss scale quickly reaches the minimum value right at the beginning, like this: [screenshot: dynamic loss-scale log]
The code snippet is shown below (following https://github.com/dptech-corp/flash-attention/blob/main/flash_attn/attention.py):

def attention(q, k, v, attn_bias, seq_len):
    # q (bsz * seq_len, num_heads, dim) = (128 * seq_len, 12, 64)
    # k (bsz * seq_len, num_heads, dim) = (128 * seq_len, 12, 64)
    # v (bsz * seq_len, num_heads, dim) = (128 * seq_len, 12, 64)
    # attn_bias (bsz, num_heads, seq_len, seq_len) = (128, 12, seq_len, seq_len)
    bsz = q.shape[0] // seq_len  # batch size (128 here)

    cu_seqlens = torch.arange(
        0, (bsz + 1) * seq_len, step=seq_len, dtype=torch.int32, device=q.device
    )
    attn = flash_attn_unpadded_func(
        q, k, v, cu_seqlens, cu_seqlens, seq_len, seq_len,
        attn_mask=None, attn_bias=attn_bias,
        dropout_p=0.0,
        softmax_scale=1.0, causal=False
    )
    return attn

It's not trivial to figure out. Here are some ideas from my side: 1) check whether the half precision overflows because of its limited representation range; 2) the attention bias & mask handling here is not as general as the PyTorch version. Broadcasting is very flexible in PyTorch, but it takes more effort to implement when all operations are fused into one kernel. We implemented a limited set of shapes to fit our model, so it does not generalize to all models. Please check your shapes carefully against the supported list.

Thank you for the reply.

* add odd length support

* add mask in attn_mask & attn_bias

* rm useless files

* move if to the outer loop

* remove comments

---------

Co-authored-by: xhj <[email protected]>
@subercui

subercui commented Apr 9, 2023

Hi, thanks everyone for bringing up this enhancement! Is this PR a way to support custom attention masks? Is this the best workaround so far, given that it is not officially supported yet?

@robotcator (Contributor)

Hi, thanks everyone for bringing up this enhancement! Is this PR a way to support custom attention masks? Is this the best workaround so far, given that it is not officially supported yet?

For the padding mask, I think the official repo already supports it. For custom attention masks, we also support some shapes, but not all:

mask_head_mod_size = mask_sizes[1];
mask_seq_mod_size = mask_sizes[2];
TORCH_CHECK(mask_sizes[1] == 1 || mask_sizes[1] == num_heads);
TORCH_CHECK(mask_sizes[2] == 1 || mask_sizes[2] == max_seqlen_q_);


The first comment in the conversation indicates that the mask has to be [bsz, nh, 1, k_len], but the code here suggests that it supports the full [bsz, nh, q_len, k_len], with broadcasting over nh and q_len.
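Concretely, the checks above would accept mask shapes like the following (sketch; bsz, nh, q_len, k_len are placeholder sizes):

import torch

bsz, nh, q_len, k_len = 2, 8, 128, 128

ok_masks = [
    torch.zeros(bsz, 1,  1,     k_len),  # key-padding mask, broadcast over heads and queries
    torch.zeros(bsz, nh, 1,     k_len),  # per-head key mask
    torch.zeros(bsz, 1,  q_len, k_len),  # full (q, k) mask shared across heads
    torch.zeros(bsz, nh, q_len, k_len),  # fully general additive mask
]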

@Birdylx

Birdylx commented Aug 15, 2023

@robotcator I have a question about attn_bias: if my attn_bias is trainable, will flash attention compute the gradient of attn_bias automatically?

@nikita-petrashen

Hello, what's up with this PR? Is the code in a usable state? I didn't quite get it from the above discussion. Thanks for your work, awesome job!

@robotcator (Contributor)

@robotcator I have a question about attn_bias: if my attn_bias is trainable, will flash attention compute the gradient of attn_bias automatically?

I don't know whether it's too late to reply; actually, the gradient of attn_bias is computed automatically.

@nofreewill42

Guys, let's face it.
It's like there is a hidden force not allowing this one to go through.
Someone is gatekeeping

@maxall41

For anyone still looking for this see: https://pytorch.org/blog/flexattention/
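For instance, a custom additive bias can be expressed as a score_mod (a minimal sketch assuming PyTorch >= 2.5; see the linked blog post for the authoritative API, and note that padding masks are better expressed with create_block_mask):

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 8, 256, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
bias = torch.randn(B, H, S, S, device="cuda", dtype=torch.float16)  # additive attention bias

def score_mod(score, b, h, q_idx, kv_idx):
    # Add a precomputed/learned bias to each attention score.
    return score + bias[b, h, q_idx, kv_idx]

out = flex_attention(q, k, v, score_mod=score_mod)  # wrap in torch.compile for speed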
