
Enable Attention Mask for Training #1516

Closed

Sanger2000 opened this issue Nov 6, 2023 · 3 comments

Comments

@Sanger2000

Sanger2000 commented Nov 6, 2023

Feature request

It appears that attention masks were originally ignored during training because they forced the slow path in PyTorch's scaled dot product attention.

I am not fully confident, but I believe custom attention masks are now supported with memory-efficient attention, as per pytorch/pytorch#104310.

It would be good to enable custom attention masks in BetterTransformer training.
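For reference, here is a minimal sketch (not Optimum code; shapes are illustrative and it assumes a CUDA device with fp16) of calling PyTorch's SDPA with a custom boolean mask, which makes the dispatcher pick the memory-efficient kernel rather than FlashAttention:

```python
# Minimal sketch: PyTorch SDPA with a custom boolean mask (illustrative shapes).
# With an explicit attn_mask, the dispatcher falls back to the memory-efficient
# kernel instead of FlashAttention.
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 8, 128, 64
q = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# True = attend, False = masked out; broadcastable to [batch, heads, seq_len, seq_len].
attn_mask = torch.ones(batch, 1, seq_len, seq_len, dtype=torch.bool, device="cuda").tril()

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```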

Motivation

I want to pass in a custom attention mask (for example, when packing multiple examples into a single sequence but only letting tokens attend to others within the same example), roughly as sketched below.
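For illustration, a block-diagonal mask for packed sequences could be built along these lines (a sketch; the helper name and lengths are made up):

```python
# Sketch: block-diagonal mask so tokens only attend within their own packed example.
import torch

def block_diagonal_mask(example_lengths):
    total = sum(example_lengths)
    mask = torch.zeros(total, total, dtype=torch.bool)
    start = 0
    for length in example_lengths:
        mask[start:start + length, start:start + length] = True  # attend within the block
        start += length
    return mask  # [seq_len, seq_len], broadcastable as SDPA's attn_mask

mask = block_diagonal_mask([5, 3, 8])  # three examples packed into one sequence
```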

Your contribution

It could be as straightforward as just removing the lines:

if self.is_training:
    attn_mask = None

in all implementations. I would be happy to do this. Perhaps it is also worth warning the user that memory-efficient attention will be used instead of FlashAttention.
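A rough sketch of what that could look like (hypothetical names, not the actual BetterTransformer internals): keep the mask during training and emit a warning about the kernel that will be used.

```python
# Hypothetical sketch of the proposed change (names are illustrative):
# keep the mask during training instead of dropping it, and warn that SDPA
# will fall back to the memory-efficient kernel when a mask is present.
import warnings
import torch

def sdpa_forward(query, key, value, attn_mask, is_training):
    if is_training and attn_mask is not None:
        warnings.warn(
            "A custom attention mask is used during training; PyTorch SDPA will "
            "use memory-efficient attention instead of FlashAttention."
        )
    # Previously: `if is_training: attn_mask = None` -- the mask is now kept.
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask
    )
```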

@fxmarty
Contributor

fxmarty commented Dec 13, 2023

Hi @Sanger2000, that's a good point. Just so you know, we are upstreaming SDPA support directly into Transformers, where it is used by default, and you can already use it for a few models (see https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention), with good performance during training (see the benchmark at huggingface/transformers#28005).
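For example, on a supported architecture the native SDPA path can be selected via `attn_implementation` (a sketch; the model name below is only an example, and the argument is available in recent Transformers releases for supported models):

```python
# Sketch of the Transformers-side SDPA path (model name is an example only).
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)
```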

I won't be putting much effort into BetterTransformer (where it is only about using SDPA, not e.g. nested tensors), but rather into extending SDPA support to more models in Transformers.

fxmarty closed this as completed Dec 13, 2023
@rightaditya

The attention mask is only supported for Memory-Efficient Attention, not FlashAttention. But even without the mask, FlashAttention won't be used anyway on CUDA because it requires fp16 or bf16 dtypes. You can check which kernel is being used by enabling only one kernel at a time with the relevant functions in torch.backends.cuda (e.g., torch.backends.cuda.sdp_kernel).
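For example (a sketch assuming a CUDA device; newer PyTorch versions expose torch.nn.attention.sdpa_kernel instead of this context manager):

```python
# Enable one SDPA backend at a time to see which kernel can handle a custom mask.
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
mask = torch.ones(1, 1, 128, 128, dtype=torch.bool, device="cuda")

# Memory-efficient attention accepts the mask.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# FlashAttention alone is expected to fail (no enabled kernel supports an arbitrary attn_mask).
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # raises RuntimeError
```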

@fxmarty Your point makes sense, but seeing as not all models are currently supported, and since this library is the recommended solution for models that aren't yet supported in the main Transformers library, it would help to allow the masks, at least when running on PyTorch v2.1+. Would you be willing to accept a pull request that addresses this?

@fxmarty
Contributor

fxmarty commented Mar 18, 2024

> The attention mask is only supported for Memory-Efficient Attention, not FlashAttention. But even without the mask, FlashAttention won't be used anyway on CUDA because it requires fp16 or bf16 dtypes. You can check which kernel is being used by enabling only one kernel at a time with the relevant functions in torch.backends.cuda (e.g., torch.backends.cuda.sdp_kernel).

Yes!

> @fxmarty Your point makes sense, but seeing as not all models are currently supported, and since this library is the recommended solution for models that aren't yet supported in the main Transformers library, it would help to allow the masks, at least when running on PyTorch v2.1+. Would you be willing to accept a pull request that addresses this?

Right, happy to review PRs. In the future, I think PRs like huggingface/transformers#28802 are the way to go (although they are taking ages to be merged).
