[WIP] Add FA2 for all Bart-like #26722
Conversation
Verified that FA2 works by checking Whisper. Bart's attention is exactly the same as Whisper's, so it should work as well. I will run some better benchmarks later. @ArthurZucker @younesbelkada could you do a first review here, just for the following files? It would be nice to agree on these files before running. Some comments:
Looks good, yeah!
Regarding your comments, I totally agree about the padding mask; this was my initial concern here. Llama needed less tolerance, but let's update it. Otherwise looks good. Let's make sure the attention is as clean as possible, as it will be the reference for cross attention.
def _flash_attention_forward(
Suggested change:
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
This is copied from as well, no?
Actually, we need a new causal function argument here to differentiate between non-causal (encoder) and causal (decoder) attention.
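To make the encoder/decoder distinction concrete, here is a minimal sketch in plain NumPy (names and shapes are illustrative, not the PR's code) of how a `causal` flag would switch between full attention for the encoder and future-masked attention for the decoder:

```python
import numpy as np

def attention_weights(q, k, causal=False):
    """Toy scaled dot-product attention; `causal` masks out future key positions."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if causal:
        seq_len = scores.shape[-1]
        # Upper-triangular entries (j > i) are future tokens for a decoder.
        future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

q = k = np.eye(3)
enc = attention_weights(q, k, causal=False)  # every query sees every key
dec = attention_weights(q, k, causal=True)   # row i only sees keys j <= i
```

In the real forward pass this flag would be forwarded to the FlashAttention kernel, which applies the causal mask internally instead of materializing it.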
@@ -2797,16 +2797,35 @@ def test_flash_attn_2_inference(self):
dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device)
Might need decoder input ids as well for cross-attention testing?
Update: the PR now works for Bart. @ArthurZucker @younesbelkada @fxmarty @LysandreJik could you take a look at the design chosen for BART here and, if it's OK, I'll apply it to all other Bart-like models. Please only review
Looks good! There are probably some docstrings (e.g. BartDecoderLayer) whose attention_mask doc should be modified accordingly.
@@ -148,7 +276,9 @@ def __init__(
    num_heads: int,
    dropout: float = 0.0,
    is_decoder: bool = False,
    is_causal: bool = False,
Personal taste, but I would add this arg after bias in case somebody is using positional arguments.
# BartFlashAttention2 attention does not support output_attentions
output_attentions = False
Don't know how it was for Llama, but I would raise an error here in case output_attentions is True.
Yes, fair. The problem, though, is that by now it would be backwards-breaking.
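One backwards-compatible middle ground (a sketch, not what the PR does) is to warn instead of raise, so existing callers that pass the flag keep working but are told it is ignored:

```python
import warnings

def flash_forward(hidden_states, output_attentions=False):
    """Toy stand-in for a FlashAttention forward pass (hypothetical helper)."""
    if output_attentions:
        # Raising here would break callers that already pass the flag,
        # so downgrade to a warning and silently ignore it instead.
        warnings.warn(
            "FlashAttention does not return attention weights; "
            "output_attentions=True is ignored.",
            UserWarning,
        )
        output_attentions = False
    attn_weights = None  # FA never materializes the full attention matrix
    return hidden_states, attn_weights

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out, weights = flash_forward([1.0, 2.0], output_attentions=True)
```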
# TODO: Bart does not have dropout in the config??
# It is recommended to use dropout with FA according to the docs
# when training.
dropout_rate = 0.0  # if not self.training else self.attn_dropout
I think Bart has some dropout:
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
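In other words, the TODO likely resolves to gating the existing `self.dropout` on training mode rather than hard-coding 0.0. A small sketch (attribute names mirror the quoted line, the class itself is a toy stand-in):

```python
class AttnSketch:
    """Toy holder mirroring the quoted Bart attention attributes."""

    def __init__(self, dropout, training):
        self.dropout = dropout
        self.training = training

    def fa_dropout_rate(self):
        # FlashAttention applies dropout internally, but only during training;
        # at inference time the rate must be 0.0.
        return self.dropout if self.training else 0.0

train = AttnSketch(dropout=0.1, training=True)
infer = AttnSketch(dropout=0.1, training=False)
```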
Only reviewed the Bart modeling file, looks good overall!
It might be slightly better if the attention logic happened in the attention class, but otherwise it's nice that we expose more clearly how attention masks need to be processed! Thanks.
if attention_mask is not None:
    attention_mask = self.causal_attn_mask_converter.to_4d(
        attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype
    )
else:
    attention_mask = self.causal_attn_mask_converter.to_causal_4d(
        input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
Would be nice if to_causal_4d supported feeding a mask and took care of this if/else, no?
Hmm, but the mask is None here, and then I need to pass all these shapes anyway.
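For illustration, a tiny NumPy version (function and argument names hypothetical, not the library's API) of a converter entry point that accepts an optional 2D padding mask and absorbs the if/else internally, at the cost of always taking the shape arguments:

```python
import numpy as np

def make_causal_4d(batch, tgt_len, kv_len, padding_mask=None):
    """Build a [batch, 1, tgt_len, kv_len] additive attention mask.

    Combines the causal triangle with an optional [batch, kv_len] padding
    mask (1 = keep, 0 = pad), so the call site doesn't need an if/else.
    """
    neg = -1e9  # stand-in for the dtype's minimum value
    # Causal part: query i may attend to keys j <= i + (kv_len - tgt_len).
    offset = kv_len - tgt_len
    causal = np.triu(np.full((tgt_len, kv_len), neg), k=offset + 1)
    mask = np.broadcast_to(causal, (batch, 1, tgt_len, kv_len)).copy()
    if padding_mask is not None:
        pad = np.where(padding_mask[:, None, None, :] == 0, neg, 0.0)
        mask = mask + pad
    return mask

m = make_causal_4d(1, 3, 3)                               # causal only
m_pad = make_causal_4d(1, 3, 3, np.array([[1, 1, 0]]))    # last key padded
```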
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
if getattr(self.config, "_flash_attn_2_enabled", False):
    encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask is not None and 0 in encoder_attention_mask) else None
I think a comment would be nice to explain why we don't pass the mask to FA if there are no 0 values in it.
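The check being discussed amounts to: if the 2D mask contains no zeros, no token is padded, so the mask can be dropped and FlashAttention can take its faster unpadded path. A hedged sketch of that guard with the suggested comment inlined (helper name is made up):

```python
import numpy as np

def prepare_fa2_mask(attention_mask):
    """Return the padding mask, or None if it would be a no-op.

    FlashAttention-2 only needs a padding mask in order to unpad
    variable-length sequences. An all-ones mask means nothing is padded,
    and passing None lets the kernel skip the unpad/repad round-trip.
    """
    if attention_mask is not None and (attention_mask == 0).any():
        return attention_mask
    return None

full = prepare_fa2_mask(np.array([[1, 1, 1]]))   # all ones -> dropped
padded_in = np.array([[1, 1, 0]])
padded_out = prepare_fa2_mask(padded_in)          # has padding -> kept
```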
if getattr(config, "_flash_attn_2_enabled", False):
    self.encoder_attn = BartFlashAttention2(
        embed_dim=self.embed_dim,
        num_heads=config.decoder_attention_heads,
        dropout=config.attention_dropout,
        is_decoder=True,
        config=config,
    )
else:
    self.encoder_attn = BartAttention(
        embed_dim=self.embed_dim,
        num_heads=config.decoder_attention_heads,
        dropout=config.attention_dropout,
        is_decoder=True,
        config=config,
    )
I think a mapping like BART_ATTENTIONS["attention_class"] will be cleaner long term if we add SDPA, flash decoding, etc., specifically given that the init arguments are consistent (and we want this to always be the case).
Sure that makes sense!
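A sketch of the suggested mapping (class bodies and registry name are illustrative stubs): because the attention variants share constructor arguments, selecting an implementation collapses to a dict lookup instead of a growing if/elif chain.

```python
class BartAttention:
    """Stub of the eager attention implementation."""

    def __init__(self, embed_dim, num_heads, dropout=0.0, is_decoder=False, config=None):
        self.embed_dim = embed_dim
        self.num_heads = num_heads

class BartFlashAttention2(BartAttention):
    """Stub of the FlashAttention-2 implementation; same constructor."""

# One place to register new implementations (sdpa, flash decoding, ...).
BART_ATTENTION_CLASSES = {
    "eager": BartAttention,
    "flash_attention_2": BartFlashAttention2,
}

def build_attention(impl, **kwargs):
    # Works only because every registered class accepts the same kwargs.
    return BART_ATTENTION_CLASSES[impl](**kwargs)

attn = build_attention("flash_attention_2", embed_dim=16, num_heads=4, is_decoder=True)
```

Adding a new backend then only touches the dict, not every call site that constructs an attention layer.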
What does this PR do?
Add FA2 to all Bart-like models