
[WIP] Add FA2 for all Bart-like #26722

Closed
wants to merge 31 commits into from

Conversation

patrickvonplaten (Contributor)

What does this PR do?

Add FA2 to all Bart-like models

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

patrickvonplaten changed the title from "[WIP] Add FA2 for all Bart-like" to "Add FA2 for all Bart-like" on Oct 10, 2023
patrickvonplaten (Contributor, PR author)

Verified that FA2 works by checking Whisper. Bart's attention is exactly the same as Whisper's, so it should work as well. I will run some more thorough benchmarks later.

@ArthurZucker @younesbelkada could you do a first review here, just for the following files:

  • tests/test_modeling_common.py
  • src/transformers/models/bart/modeling_bart.py
    [please ignore the Whisper changes entirely for now]

It would be nice to agree on these files before running make fix-copies, which will change 10+ other modeling files.
The implementation was pretty straightforward, as I could more or less copy-paste all the code from Llama (nice job @younesbelkada!).

Some comments:

  • 1.) The flash attention tests in tests/test_modeling_common.py are very nicely implemented, but the tolerance looks too high to catch incorrect masking or other settings. For example, I initially had an incorrect scaling factor in BART's attention and all tests still passed. We might want to look into this.
  • 2.) I'm not super happy about passing around both attention_mask and padding_mask all the time. It makes the code hard to read and is quite confusing (what is the difference between the two?!). As far as I understand, the two masks are the same; the only reason we use padding_mask in addition to attention_mask is that the attention_mask is already expanded and thus can't be used for FA. I wonder whether we should do a bigger refactor here: instead of expanding the attention_mask at the beginning, only expand it right before the attention, so that we don't have to pass around both padding_mask and attention_mask (see the sketch after this list). We could even cache the expanded mask if needed for speed. I would be strongly in favor of not having both a padding mask and an attention mask. Also cc @fxmarty here.
  • 3.) Can we make the automatic conversion to the intended FA precision a bit more robust? E.g. see here: https://github.com/huggingface/transformers/pull/26722/files#r1353462371 . Aren't there use cases where the user would like to train in bfloat16 but might have a layer norm in fp32? cc @younesbelkada
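A minimal sketch of the mask refactor from point 2 (the helper name is hypothetical, not the PR's code): keep only the 2D padding mask in the forward signatures and expand it to the 4D additive form just before the eager attention, while the FA2 path consumes the 2D mask directly.

from typing import Optional
import torch

def expand_mask_just_in_time(
    attention_mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
) -> torch.Tensor:
    """Expand a 2D padding mask [bsz, src_len] into the 4D additive bias
    [bsz, 1, tgt_len, src_len] that eager attention expects."""
    bsz, src_len = attention_mask.shape
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded = attention_mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    # Kept positions stay at 0.0; padded positions receive a large negative bias.
    return (1.0 - expanded) * torch.finfo(dtype).min

mask_2d = torch.tensor([[1, 1, 1, 0]])  # one padded position
mask_4d = expand_mask_just_in_time(mask_2d, torch.float32)
print(mask_4d.shape)  # torch.Size([1, 1, 4, 4])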

@ArthurZucker (Collaborator) left a comment


Looks good, yeah!
Regarding your comments: I totally agree about the padding mask; that was my initial concern here. Llama needed less tolerance, but let's update it. Otherwise this looks good. Let's make sure the attention is as clean as possible, as it will be the reference for cross-attention.

Comment on lines 422 to 423

def _flash_attention_forward(
Collaborator

Suggested change
def _flash_attention_forward(
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(

This one is copied from Llama as well, no?

patrickvonplaten (PR author)

Actually, we need a new causal function argument here to differentiate between non-causal (encoder) and causal (decoder) attention; a sketch follows below.
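For context, a minimal sketch (assuming the flash-attn 2 package; not the PR's exact code) of how such a flag could be threaded through to the kernel: the decoder self-attention passes causal=True, while the encoder and cross-attention pass causal=False.

import torch
from flash_attn import flash_attn_func  # flash-attn 2.x kernel entry point

def _flash_attention_forward_sketch(
    query: torch.Tensor,   # [batch, seq_len, num_heads, head_dim], fp16/bf16 on GPU
    key: torch.Tensor,
    value: torch.Tensor,
    dropout: float,
    is_causal: bool,       # True for decoder self-attention, False otherwise
) -> torch.Tensor:
    return flash_attn_func(query, key, value, dropout_p=dropout, causal=is_causal)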

@@ -2797,16 +2797,35 @@ def test_flash_attn_2_inference(self):
dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device)
Collaborator

We might need decoder input ids as well, for cross-attention testing?
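A hedged sketch of what that could look like in the common test (model, torch_device, and the token IDs stand in for values provided by the surrounding fixture):

import torch

dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device)
# Encoder-decoder models additionally need decoder inputs so that the
# cross-attention (and its padding handling) is actually exercised.
if model.config.is_encoder_decoder:
    dummy_decoder_input_ids = torch.LongTensor([[2, 10, 11, 12]]).to(torch_device)
    outputs = model(dummy_input, decoder_input_ids=dummy_decoder_input_ids)
else:
    outputs = model(dummy_input)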

patrickvonplaten (PR author) commented Oct 24, 2023

Update:

The PR now works for Bart. @ArthurZucker @younesbelkada @fxmarty @LysandreJik could you take a look at the design chosen for BART here? If it's OK, I'll apply it to all the other Bart-like models.

Please only review modeling_bart.py!

fxmarty previously approved these changes Oct 25, 2023
@fxmarty (Contributor) left a comment


Looks good! There are probably some docstrings (e.g. BartDecoderLayer) whose attention_mask documentation should be updated accordingly.

@@ -148,7 +276,9 @@ def __init__(
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
is_causal: bool = False,
Contributor

Personal taste, but I would add this argument after bias, in case somebody is using positional arguments.
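To illustrate the concern (simplified signatures, not the actual BartAttention ones): inserting the new argument before bias would silently shift the meaning of existing positional calls, while appending it after bias keeps them working.

class AttentionOld:
    def __init__(self, embed_dim, num_heads, dropout=0.0, is_decoder=False, bias=True):
        self.bias = bias

class AttentionNew:
    # is_causal appended after bias: old positional calls keep their meaning.
    def __init__(self, embed_dim, num_heads, dropout=0.0, is_decoder=False, bias=True, is_causal=False):
        self.bias, self.is_causal = bias, is_causal

# This existing positional call still sets bias=False with AttentionNew;
# had is_causal been inserted before bias, False would have landed on
# is_causal and bias would have silently stayed True.
attn = AttentionNew(512, 8, 0.0, True, False)
assert attn.bias is False and attn.is_causal is False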

Comment on lines +446 to +447
# BartFlashAttention2 attention does not support output_attentions
output_attentions = False
Contributor

I don't know how it was done for Llama, but I would raise an error here in case output_attentions is True.

patrickvonplaten (PR author)

Yes, fair point. The problem, though, is that changing it now would be backwards-breaking.
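One possible middle ground, sketched here as an assumption rather than what the PR does: keep the silent override for backwards compatibility but surface it with a warning, since flash attention never materializes the attention weights.

import warnings

output_attentions = True  # example value as received by the attention layer

if output_attentions:
    warnings.warn(
        "BartFlashAttention2 cannot return attention weights; "
        "output_attentions=True is ignored and reset to False.",
        UserWarning,
    )
    output_attentions = False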

Comment on lines +498 to +501
# TODO: Bart does not have dropout in the config??
# It is recommended to use dropout with FA according to the docs
# when training.
dropout_rate = 0.0 # if not self.training else self.attn_dropout
Contributor

I think Bart has some dropout:

attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
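Given that line, one way the TODO could be resolved (a sketch of a single line inside the FA2 forward, mirroring the Llama code path rather than the PR's final code):

# Reuse the attention dropout that BartAttention already applies to the
# attention probabilities, but only while training; flash-attn takes it
# as its dropout_p argument.
dropout_rate = self.dropout if self.training else 0.0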

fxmarty dismissed their stale review October 25, 2023 07:42

wanted to comment rather than approve

@ArthurZucker (Collaborator) left a comment


I only reviewed the BART modeling file; it looks good overall!
It might be slightly better if the attention logic happened in the attention class, but otherwise it's nice that we now expose more clearly how attention masks need to be processed. Thanks!

Comment on lines +1424 to +1431
if attention_mask is not None:
attention_mask = self.causal_attn_mask_converter.to_4d(
attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype
)
else:
attention_mask = self.causal_attn_mask_converter.to_causal_4d(
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
Collaborator

It would be nice if to_causal_4d supported feeding a mask and took care of this if/else itself, no?

patrickvonplaten (PR author)

Hmm, but the mask is None here, and then I need to pass all these shapes anyway.
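A sketch of the reviewer's suggestion as a hypothetical wrapper (converter stands for the mask-converter object used in the quoted code; the helper name and signature are assumptions), folding the if/else into a single call:

from typing import Optional
import torch

def build_4d_causal_mask(
    converter,                               # object providing to_4d / to_causal_4d
    attention_mask: Optional[torch.Tensor],  # 2D [bsz, seq_len] padding mask, or None
    batch_size: int,
    query_length: int,
    key_value_length: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    if attention_mask is not None:
        # Merge the padding mask with the causal mask.
        return converter.to_4d(attention_mask, query_length, key_value_length, dtype=dtype)
    # No padding mask: build the pure causal mask from the shapes alone.
    return converter.to_causal_4d(batch_size, query_length, key_value_length, dtype=dtype, device=device)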

# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
if getattr(self.config, "_flash_attn_2_enabled", False):
encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask is not None and 0 in encoder_attention_mask) else None
Collaborator

I think a comment would be nice here to say why we don't pass the mask to FA if there are no 0 values in it.
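For example, the requested comment could read roughly like this (same logic as the quoted lines, reworded with the reasoning spelled out, and assuming the check is applied to the 2D padding mask before it is expanded):

if getattr(self.config, "_flash_attn_2_enabled", False):
    # A padding mask without any 0 entries masks nothing, so passing None
    # lets flash attention skip the unpad/repad round-trip and take its
    # faster unpadded path.
    if encoder_attention_mask is not None and 0 not in encoder_attention_mask:
        encoder_attention_mask = None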

Comment on lines +739 to +754
if getattr(config, "_flash_attn_2_enabled", False):
self.encoder_attn = BartFlashAttention2(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
)
else:
self.encoder_attn = BartAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
)
Collaborator

I think a mapping like BART_ATTENTIONS["attention_class"] will be cleaner long term if we add SDPA, flash decoding, etc., especially given that the init arguments are consistent (and we want this to always be the case).

patrickvonplaten (PR author)

Sure that makes sense!
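A sketch of that mapping (the dict name and keys are illustrative; BartAttention, BartFlashAttention2, config, and self come from the quoted snippet above):

BART_ATTENTION_CLASSES = {
    "eager": BartAttention,
    "flash_attention_2": BartFlashAttention2,
}

attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "eager"
# One construction site regardless of implementation; adding sdpa or flash
# decoding later only means registering another entry in the dict.
self.encoder_attn = BART_ATTENTION_CLASSES[attn_type](
    embed_dim=self.embed_dim,
    num_heads=config.decoder_attention_heads,
    dropout=config.attention_dropout,
    is_decoder=True,
    config=config,
)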

patrickvonplaten marked this pull request as draft October 26, 2023 15:17
patrickvonplaten changed the title from "Add FA2 for all Bart-like" to "[WIP] Add FA2 for all Bart-like" on Oct 26, 2023